From 6b85c3e25cfe690ba3d01bb377f133880383c62d Mon Sep 17 00:00:00 2001
From: timonmerk
Date: Sun, 6 Oct 2024 23:50:54 +0200
Subject: [PATCH] refactor stream

---
 py_neuromodulation/__init__.py                |   1 +
 py_neuromodulation/stream/__init__.py         |   4 +-
 .../stream/data_generator_abc.py              |  12 +
 py_neuromodulation/stream/data_processor.py   |  11 +-
 py_neuromodulation/stream/generator.py        |  53 ---
 .../{mnelsl_stream.py => mnelsl_generator.py} |  56 ++-
 .../stream/rawdata_generator.py               | 159 +++++++
 py_neuromodulation/stream/stream.py           | 413 ++----------------
 py_neuromodulation/utils/data_writer.py       |  51 +++
 py_neuromodulation/utils/io.py                |  11 -
 tests/test_lsl_stream.py                      |   4 +-
 11 files changed, 309 insertions(+), 466 deletions(-)
 create mode 100644 py_neuromodulation/stream/data_generator_abc.py
 delete mode 100644 py_neuromodulation/stream/generator.py
 rename py_neuromodulation/stream/{mnelsl_stream.py => mnelsl_generator.py} (74%)
 create mode 100644 py_neuromodulation/stream/rawdata_generator.py
 create mode 100644 py_neuromodulation/utils/data_writer.py
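Sketch of how the refactored pieces compose after this patch (offline case).
The class, argument, and module names are taken from the diffs below; the
synthetic data, the "out" directory, and passing the optional DataProcessor
arguments explicitly are illustrative assumptions, and NMSettings.load(None)
is assumed to read the default settings.yaml as in the previous Stream
constructor:

    import asyncio
    import numpy as np

    from py_neuromodulation.stream import NMSettings, RawDataGenerator, Stream
    from py_neuromodulation.stream.data_processor import DataProcessor
    from py_neuromodulation.utils import create_channels
    from py_neuromodulation.utils.data_writer import DataWriter

    # 10 s of synthetic data, 2 channels at 1 kHz (shape: channels x time)
    sfreq = 1000.0
    data = np.random.randn(2, int(10 * sfreq))

    # channel definition, mirroring how create_channels is used in the diffs
    channels = create_channels(
        ch_names=["ch0", "ch1"],
        ch_types=["ecog", "ecog"],
        used_types=["eeg", "ecog", "dbs", "seeg"],
    )
    settings = NMSettings.load(None)  # assumed: default settings.yaml

    generator = RawDataGenerator(
        data,
        settings.sampling_rate_features_hz,
        settings.segment_length_features_ms,
        channels,
        sfreq,
    )
    processor = DataProcessor(
        sfreq=sfreq,
        settings=settings,
        channels=channels,
        path_grids=None,
        coord_names=None,
        coord_list=None,
        line_noise=50,
        verbose=True,
    )
    writer = DataWriter(out_dir="out", save_csv=True, experiment_name="sub")

    # Stream.run is now a coroutine, so offline use wraps it in asyncio.run
    features = asyncio.run(
        Stream(verbose=True).run(
            data_processor=processor,
            data_generator=generator,
            data_writer=writer,
        )
    )

A GUI caller can instead schedule run() on its own event loop and push "stop"
through stream_handling_queue, as shown after the diffs.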
diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py
index f1e245e3..79607cfb 100644
--- a/py_neuromodulation/__init__.py
+++ b/py_neuromodulation/__init__.py
@@ -68,6 +68,7 @@
 from .utils import types
 from .utils import io
+from .utils import data_writer
 
 from . import stream
 from . import analysis
diff --git a/py_neuromodulation/stream/__init__.py b/py_neuromodulation/stream/__init__.py
index 31f14b29..b9de0e63 100644
--- a/py_neuromodulation/stream/__init__.py
+++ b/py_neuromodulation/stream/__init__.py
@@ -1,5 +1,5 @@
-from .generator import RawDataGenerator
+from .rawdata_generator import RawDataGenerator
 from .mnelsl_player import LSLOfflinePlayer
-from .mnelsl_stream import LSLStream
+from .mnelsl_generator import MNELSLGenerator
 from .stream import Stream
 from .settings import NMSettings
diff --git a/py_neuromodulation/stream/data_generator_abc.py b/py_neuromodulation/stream/data_generator_abc.py
new file mode 100644
index 00000000..4fbb233e
--- /dev/null
+++ b/py_neuromodulation/stream/data_generator_abc.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+
+class DataGeneratorABC(ABC):
+
+    def __init__(self) -> None:
+        pass
+
+    @abstractmethod
+    def __next__(self) -> Tuple["np.ndarray", "np.ndarray"]:
+        pass
diff --git a/py_neuromodulation/stream/data_processor.py b/py_neuromodulation/stream/data_processor.py
index 3fc323ed..b5ca8903 100644
--- a/py_neuromodulation/stream/data_processor.py
+++ b/py_neuromodulation/stream/data_processor.py
@@ -55,6 +55,9 @@ def __init__(
         self.sfreq_raw: float = sfreq // 1
         self.line_noise: float | None = line_noise
         self.path_grids: _PathLike | None = path_grids
+        if path_grids is None:
+            import py_neuromodulation as nm
+            self.path_grids = nm.PYNM_DIR  # NOTE: could be optimized
         self.verbose: bool = verbose
 
         self.features_previous = None
@@ -315,11 +318,3 @@ def save_settings(self, out_dir: _PathLike, prefix: str = "") -> None:
 
     def save_channels(self, out_dir: _PathLike, prefix: str) -> None:
         io.save_channels(self.channels, out_dir, prefix)
-
-    def save_features(
-        self,
-        feature_arr: "pd.DataFrame",
-        out_dir: _PathLike = "",
-        prefix: str = "",
-    ) -> None:
-        io.save_features(feature_arr, out_dir, prefix)
diff --git a/py_neuromodulation/stream/generator.py b/py_neuromodulation/stream/generator.py
deleted file mode 100644
index bff9becb..00000000
--- a/py_neuromodulation/stream/generator.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-
-class RawDataGenerator:
-    """
-    This generator function mimics online data acquisition.
-    The data are iteratively sampled with settings.sampling_rate_features_hz
-    """
-
-    def __init__(
-        self,
-        data: np.ndarray,
-        sfreq: float,
-        sampling_rate_features_hz: float,
-        segment_length_features_ms: float,
-    ) -> None:
-        """
-        Arguments
-        ---------
-        data (np array): shape (channels, time)
-        settings (settings.NMSettings): settings object
-        sfreq (float): sampling frequency of the data
-
-        Returns
-        -------
-        np.array: 1D array of time stamps
-        np.array: new batch for run function of full segment length shape
-        """
-        self.batch_counter: int = 0  # counter for the batches
-
-        self.data = data
-        self.sfreq = sfreq
-        # Width, in data points, of the moving window used to calculate features
-        self.segment_length = segment_length_features_ms / 1000 * sfreq
-        # Ratio of the sampling frequency of the input data to the sampling frequency
-        self.stride = sfreq / sampling_rate_features_hz
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        start = self.stride * self.batch_counter
-        end = start + self.segment_length
-
-        self.batch_counter += 1
-
-        start_idx = int(start)
-        end_idx = int(end)
-
-        if end_idx > self.data.shape[1]:
-            raise StopIteration
-
-        return np.arange(start, end) / self.sfreq, self.data[:, start_idx:end_idx]
diff --git a/py_neuromodulation/stream/mnelsl_stream.py b/py_neuromodulation/stream/mnelsl_generator.py
similarity index 74%
rename from py_neuromodulation/stream/mnelsl_stream.py
rename to py_neuromodulation/stream/mnelsl_generator.py
index 0d20329e..2d703311 100644
--- a/py_neuromodulation/stream/mnelsl_stream.py
+++ b/py_neuromodulation/stream/mnelsl_generator.py
@@ -1,21 +1,26 @@
 from collections.abc import Iterator
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
 import numpy as np
 from py_neuromodulation.utils import logger
 from mne_lsl.lsl import resolve_streams
 import os
+from .data_generator_abc import DataGeneratorABC
 
 if TYPE_CHECKING:
     from py_neuromodulation import NMSettings
 
 
-class LSLStream:
+class MNELSLGenerator(DataGeneratorABC):
     """
     Class is used to create and connect to a LSL stream and pull data from it.
     """
 
-    def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> None:
+    def __init__(self,
+        segment_length_features_ms: float,
+        sampling_rate_features_hz: float,
+        stream_name: str | None = "example_stream",
+    ) -> None:
         """
         Initialize the LSL stream.
@@ -36,7 +41,6 @@ def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> None:
         self.stream: StreamLSL
         # self.keyboard_interrupt = False
 
-        self.settings = settings
         self._n_seconds_wait_before_disconnect = 3
         try:
             if stream_name is None:
@@ -55,18 +59,30 @@ def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> None:
         else:
             self.sinfo = self.stream.sinfo
 
-        self.winsize = settings.segment_length_features_ms / self.stream.sinfo.sfreq
-        self.sampling_interval = 1 / self.settings.sampling_rate_features_hz
-
-        # If not running the generator when the escape key is pressed.
-        self.headless: bool = not os.environ.get("DISPLAY")
-        # if not self.headless:
-        #     from py_neuromodulation.utils.keyboard import KeyboardListener
-
-        #     self.listener = KeyboardListener(("esc", self.set_keyboard_interrupt))
-        #     self.listener.start()
-
-    def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
+        self.winsize = segment_length_features_ms / self.stream.sinfo.sfreq
+        self.sampling_interval = 1 / sampling_rate_features_hz
+        self.channels = self.get_LSL_channels()
+        self.sfreq = self.stream.sinfo.sfreq
+
+    def get_LSL_channels(self) -> "pd.DataFrame":
+
+        from py_neuromodulation.utils import create_channels
+        ch_names = self.sinfo.get_channel_names() or [
+            "ch" + str(i) for i in range(self.sinfo.n_channels)
+        ]
+        ch_types = self.sinfo.get_channel_types() or [
+            "eeg" for i in range(self.sinfo.n_channels)
+        ]
+        return create_channels(
+            ch_names=ch_names,
+            ch_types=ch_types,
+            used_types=["eeg", "ecog", "dbs", "seeg"],
+        )
+
+    def __iter__(self):
+        return self.__next__()  # __next__ is a generator function; return the iterator it creates
+
+    def __next__(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
         self.last_time = time.time()
         check_data = None
         data = None
@@ -112,11 +128,3 @@ def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
             yield timestamp, data
 
             logger.info(f"Stream time: {timestamp[-1] - stream_start_time}")
-
-            # if not self.headless and self.keyboard_interrupt:
-            #     logger.info("Keyboard interrupt")
-            #     self.listener.stop()
-            #     self.stream.disconnect()
-
-    # def set_keyboard_interrupt(self):
-    #     self.keyboard_interrupt = True
diff --git a/py_neuromodulation/stream/rawdata_generator.py b/py_neuromodulation/stream/rawdata_generator.py
new file mode 100644
index 00000000..c19bed29
--- /dev/null
+++ b/py_neuromodulation/stream/rawdata_generator.py
@@ -0,0 +1,164 @@
+from py_neuromodulation.utils import logger
+from py_neuromodulation.utils.io import MNE_FORMATS, read_mne_data, load_channels
+from py_neuromodulation.utils.types import _PathLike
+from py_neuromodulation.utils import create_channels
+from .data_generator_abc import DataGeneratorABC
+from pathlib import Path
+import numpy as np
+import pandas as pd
+from typing import Tuple
+
+class RawDataGenerator(DataGeneratorABC):
+    """
+    This generator class mimics online data acquisition.
+    The data are iteratively sampled with sampling_rate_features_hz
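+
+    Example (illustrative; assumes 1 kHz raw data and a 10 Hz feature rate):
+        gen = RawDataGenerator(data, 10.0, 1000.0, channels, 1000.0)
+        for timestamps, data_batch in gen:
+            ...  # data_batch shape: (n_channels, 1000 samples per 1000 ms window)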
+    """
+
+    def __init__(
+        self,
+        data: "np.ndarray | pd.DataFrame | _PathLike | None",
+        sampling_rate_features_hz: float,
+        segment_length_features_ms: float,
+        channels: "pd.DataFrame | None",
+        sfreq: "float | None",
+    ) -> None:
+        """
+        Arguments
+        ---------
+        data (np.ndarray | pd.DataFrame | _PathLike): shape (channels, time),
+            or a path to an MNE-readable file
+        sampling_rate_features_hz (float): rate at which feature batches are sampled
+        segment_length_features_ms (float): width of the moving feature window in ms
+        channels (pd.DataFrame | None): channel definition; created from the file
+            metadata if not provided
+        sfreq (float | None): sampling frequency of the data; read from the file
+            if data is passed as a path
+        """
+        self.channels = load_channels(channels) if channels is not None else None
+        self.sfreq = sfreq
+        self.batch_counter: int = 0  # counter for the batches
+        self.target_idx_initialized: bool = False
+
+        if isinstance(data, (np.ndarray, pd.DataFrame)):
+            logger.info(f"Loading data from {type(data).__name__}")
+            self.data = self._handle_data(data)
+        elif isinstance(data, _PathLike):
+            logger.info("Loading data from file")
+            filepath = Path(data)  # type: ignore
+            ext = filepath.suffix
+
+            if ext in MNE_FORMATS:
+                data, sfreq, ch_names, ch_types, bads = read_mne_data(filepath)
+            else:
+                raise ValueError(f"Unsupported file format: {ext}")
+            if self.channels is None:
+                self.channels = create_channels(
+                    ch_names=ch_names,
+                    ch_types=ch_types,
+                    used_types=["eeg", "ecog", "dbs", "seeg"],
+                    bads=bads,
+                )
+
+            if sfreq is None:
+                raise ValueError(
+                    "Sampling frequency not specified in file, please specify sfreq as a parameter"
+                )
+            self.sfreq = sfreq
+            self.data = self._handle_data(data)
+        else:
+            raise ValueError(
+                "Data must be either a numpy array, a pandas DataFrame, or a path to an MNE supported file"
+            )
+        # Width, in data points, of the moving window used to calculate features
+        self.segment_length = segment_length_features_ms / 1000 * self.sfreq
+        # Ratio of the input sampling frequency to the feature sampling rate
+        self.stride = self.sfreq / sampling_rate_features_hz
+
+    def _handle_data(self, data: "np.ndarray | pd.DataFrame") -> np.ndarray:
+        """Validate the input data against the channel definition.
+
+        Args:
+            data (np.ndarray | pd.DataFrame): raw data of shape (channels, time)
+        Raises:
+            ValueError: if the number of channels or the channel names do not
+                match the names in `channels`
+
+        Returns:
+            np.ndarray: validated data array of shape (channels, time)
+        """
+        names_expected = self.channels["name"].to_list()
+
+        if isinstance(data, np.ndarray):
+            if not len(names_expected) == data.shape[0]:
+                raise ValueError(
+                    "If data is passed as an array, the first dimension must"
+                    " match the number of channel names in `channels`.\n"
+                    f" Number of data channels (data.shape[0]): {data.shape[0]}\n"
+                    f' Length of channels["name"]: {len(names_expected)}.'
+                )
+            return data
+
+        names_data = data.columns.to_list()
+
+        if not (
+            len(names_expected) == len(names_data)
+            and sorted(names_expected) == sorted(names_data)
+        ):
+            raise ValueError(
+                "If data is passed as a DataFrame, the"
+                " column names must match the channel names in `channels`.\n"
+                f"Input dataframe column names: {names_data}\n"
+                f'Expected (from channels["name"]): {names_expected}.'
+            )
+        return data.to_numpy().transpose()
+
+    def add_target(self, feature_dict: "pd.DataFrame", data_batch: np.ndarray) -> "pd.DataFrame":
+        """Add the latest sample of every target channel to the feature dict.
+
+        Parameters
+        ----------
+        feature_dict : pd.DataFrame
+        data_batch : np.ndarray
+
+        Returns
+        -------
+        pd.DataFrame
+            feature dict with target channels added
+        """
+        if not (isinstance(self.channels, pd.DataFrame)):
+            raise ValueError("Channels must be a pandas DataFrame")
+
+        if self.channels["target"].sum() > 0:
+            if not self.target_idx_initialized:
+                self.target_indexes = self.channels[self.channels["target"] == 1].index
+                self.target_names = self.channels.loc[
+                    self.target_indexes, "name"
+                ].to_list()
+                self.target_idx_initialized = True
+
+            for target_idx, target_name in zip(self.target_indexes, self.target_names):
+                feature_dict[target_name] = data_batch[target_idx, -1]
+
+        return feature_dict
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
+        start = self.stride * self.batch_counter
+        end = start + self.segment_length
+
+        self.batch_counter += 1
+
+        start_idx = int(start)
+        end_idx = int(end)
+
+        if end_idx > self.data.shape[1]:
+            raise StopIteration
+
+        return np.arange(start, end) / self.sfreq, self.data[:, start_idx:end_idx]
diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py
index f60fec97..ef66cb69 100644
--- a/py_neuromodulation/stream/stream.py
+++ b/py_neuromodulation/stream/stream.py
@@ -1,28 +1,20 @@
 """Module for generic and offline data streams."""
 
+import asyncio
 from typing import TYPE_CHECKING
 from collections.abc import Iterator
 import numpy as np
-import pandas as pd
-from pathlib import Path
 import multiprocessing as mp
 from contextlib import suppress
 
-from py_neuromodulation.stream.data_processor import DataProcessor
-from py_neuromodulation.utils.io import MNE_FORMATS, read_mne_data
-from py_neuromodulation.utils.types import _PathLike
-from py_neuromodulation.stream.settings import NMSettings
 from py_neuromodulation.features import USE_FREQ_RANGES
-from py_neuromodulation.utils import (
-    logger,
-    create_default_channels_from_data,
-    load_channels,
-    save_features,
-    create_channels,
-)
+from py_neuromodulation.utils.types import _PathLike
+from py_neuromodulation.utils import logger
+from py_neuromodulation.utils.data_writer import DataWriter
 from py_neuromodulation.gui.backend.app_socket import WebSocketManager
-from py_neuromodulation import PYNM_DIR
+from py_neuromodulation.stream.rawdata_generator import RawDataGenerator
+from py_neuromodulation.stream.data_processor import DataProcessor
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -39,296 +31,59 @@ class Stream:
     def __init__(
         self,
-        data: "np.ndarray | pd.DataFrame | _PathLike | None" = None,
-        sfreq: float | None = None,
-        experiment_name: str = "sub",
-        channels: "pd.DataFrame | _PathLike | None" = None,
-        is_stream_lsl: bool = False,
-        stream_lsl_name: str | None = None,
-        settings: NMSettings | _PathLike | None = None,
-        line_noise: float | None = 50,
-        sampling_rate_features_hz: float | None = None,
-        path_grids: _PathLike | None = None,
-        coord_names: list | None = None,
-        coord_list: list | None = None,
         verbose: bool = True,
     ) -> None:
-        """Stream initialization
-
-        Parameters
-        ----------
-        sfreq : float
-            sampling frequency of data in Hertz
-        channels : pd.DataFrame | _PathLike
-            parametrization of channels (see define_channels.py for initialization)
-        data : np.ndarray | pd.DataFrame | None, optional
-            data to be streamed with shape (n_channels, n_time), by default None
-        settings : NMSettings | _PathLike | None, optional
-            Initialized settings.NMSettings object, by default the py_neuromodulation/settings.yaml are read
-            and passed into a settings object
-        line_noise : float | None, optional
-            line noise, by default 50
-        sampling_rate_features_hz : float | None, optional
-            feature sampling rate, by default None
-        path_grids : _PathLike | None, optional
-            path to grid_cortex.tsv and/or gird_subcortex.tsv, by default Non
-        coord_names : list | None, optional
-            coordinate name in the form [coord_1_name, coord_2_name, etc], by default None
-        coord_list : list | None, optional
-            coordinates in the form [[coord_1_x, coord_1_y, coord_1_z], [coord_2_x, coord_2_y, coord_2_z],], by default None
-        verbose : bool, optional
-            print out stream computation time information, by default True
-        """
-        # Input params
-        self.path_grids = path_grids
         self.verbose = verbose
-        self.line_noise = line_noise
-        self.coord_names = coord_names
-        self.coord_list = coord_list
-        self.experiment_name = experiment_name
-        self.data = data
-        self.settings: NMSettings = NMSettings.load(settings)
-        self.is_stream_lsl = is_stream_lsl
-        self.stream_lsl_name = stream_lsl_name
-
-        self.sess_right = None
-        self.projection = None
-        self.model = None
-
-        if sampling_rate_features_hz is not None:
-            self.settings.sampling_rate_features_hz = sampling_rate_features_hz
-
-        if path_grids is None:
-            path_grids = PYNM_DIR
-
-        # Set up some flags for stream processing later
         self.is_running = False
-        self.target_idx_initialized: bool = False
-
-        # Validate input depending on stream type and initialize stream
-        self.generator: Iterator
-
-        if self.is_stream_lsl:
-            from py_neuromodulation.stream.mnelsl_stream import LSLStream
-
-            if self.stream_lsl_name is None:
-                logger.info(
-                    "No stream name specified. Will connect to the first available stream if it exists."
-                )
-
-            print(self.stream_lsl_name)
-            self.lsl_stream = LSLStream(
-                settings=self.settings, stream_name=self.stream_lsl_name
-            )
-
-            sinfo = self.lsl_stream.sinfo
-
-            # If no sampling frequency is specified in the stream, try to get it from the passed parameters
-            if sinfo.sfreq is None:
-                logger.info("No sampling frequency specified in LSL stream")
-                if sfreq is not None:
-                    logger.info("Using sampling frequency passed to Stream constructor")
-                else:
-                    raise ValueError(
-                        "No sampling frequency specified in stream and no sampling frequency passed to Stream constructor"
-                    )
-            else:
-                if sfreq is not None != sinfo.sfreq:
-                    logger.info(
-                        "Sampling frequency of the LSL stream does not match the passed sampling frequency."
-                    )
-                    logger.info("Using sampling frequency of the LSL stream")
-                self.sfreq = sinfo.sfreq
-
-            # TONI: should we try to get channels from the passed "channels" parameter before generating default?
-
-            # Try to get channel names and types from the stream, if not generate default
-            ch_names = sinfo.get_channel_names() or [
-                "ch" + str(i) for i in range(sinfo.n_channels)
-            ]
-            ch_types = sinfo.get_channel_types() or [
-                "eeg" for i in range(sinfo.n_channels)
-            ]
-            self.channels = create_channels(
-                ch_names=ch_names,
-                ch_types=ch_types,
-                used_types=["eeg", "ecog", "dbs", "seeg"],
-            )
+
-            self.generator = self.lsl_stream.get_next_batch()
-
-        else:  # Data passed as array, dataframe or path to file
-            if data is None:
-                raise ValueError(
-                    "If is_stream_lsl is False, data must be passed to the Stream constructor"
-                )
-
-            # If channels passed to constructor, try to load them
-            self.channels = load_channels(channels) if channels is not None else None
-
-            if isinstance(self.data, (np.ndarray, pd.DataFrame)):
-                logger.info(f"Loading data from {type(data).__name__}")
-
-                if sfreq is None:
-                    raise ValueError(
-                        "sfreq must be specified when passing data as an array or dataframe"
-                    )
-
-                self.sfreq = sfreq
-
-                if self.channels is None:
-                    self.channels = create_default_channels_from_data(self.data)
-
-                self.data = self._handle_data(self.data)
-
-            elif isinstance(self.data, _PathLike):
-                # If data is a path, try to load it as an MNE supported file
-                logger.info("Loading data from file")
-                filepath = Path(self.data)  # type: ignore
-                ext = filepath.suffix
-
-                if ext in MNE_FORMATS:
-                    data, sfreq, ch_names, ch_types, bads = read_mne_data(filepath)
-                else:
-                    raise ValueError(f"Unsupported file format: {ext}")
-
-                if sfreq is None:
-                    raise ValueError(
-                        "Sampling frequency not specified in file, please specify sfreq as a parameters"
-                    )
-
-                self.sfreq = sfreq
-
-                self.channels = create_channels(
-                    ch_names=ch_names,
-                    ch_types=ch_types,
-                    used_types=["eeg", "ecog", "dbs", "seeg"],
-                    bads=bads,
-                )
-
-                # _handle_data requires the channels to be set
-                self.data = self._handle_data(data)
-
-            else:
-                raise ValueError(
-                    "Data must be either a numpy array, a pandas DataFrame, or a path to an MNE supported file"
-                )
-
-        from py_neuromodulation.stream.generator import RawDataGenerator
-
-        self.generator: Iterator = RawDataGenerator(
-            self.data,
-            self.sfreq,
-            self.settings.sampling_rate_features_hz,
-            self.settings.segment_length_features_ms,
-        )
-
-        self._initialize_data_processor()
-
-    def _add_target(self, feature_dict: dict, data: np.ndarray) -> None:
-        """Add target channels to feature series.
-
-        Parameters
-        ----------
-        feature_dict : dict
-        data : np.ndarray
-            Raw data with shape (n_channels, n_samples).
-            Channels not usd for feature computation are also included
-
-        Returns
-        -------
-        dict
-            feature dict with target channels added
-        """
-        if not (isinstance(self.channels, pd.DataFrame)):
-            raise ValueError("Channels must be a pandas DataFrame")
-
-        if self.channels["target"].sum() > 0:
-            if not self.target_idx_initialized:
-                self.target_indexes = self.channels[self.channels["target"] == 1].index
-                self.target_names = self.channels.loc[
-                    self.target_indexes, "name"
-                ].to_list()
-                self.target_idx_initialized = True
-
-            for target_idx, target_name in zip(self.target_indexes, self.target_names):
-                feature_dict[target_name] = data[target_idx, -1]
-
-    def run(
+    async def run(
         self,
-        out_dir: _PathLike = "",
-        save_csv: bool = False,
-        save_interval: int = 10,
-        return_df: bool = True,
-        stream_handling_queue: "mp.Queue | None" = None,
+        data_processor: DataProcessor | None = None,
+        data_generator: Iterator | None = None,
+        data_writer: DataWriter | None = None,
+        stream_handling_queue: asyncio.Queue | None = None,
         websocket_featues: WebSocketManager | None = None,
     ):
+        self.data_processor = data_processor
         # Check that at least one channel is selected for analysis
-        if self.channels.query("used == 1 and target == 0").shape[0] == 0:
+        if self.data_processor.channels.query("used == 1 and target == 0").shape[0] == 0:
             raise ValueError(
                 "No channels selected for analysis that have column 'used' = 1 and 'target' = 0. Please check your channels"
             )
 
         # If features that use frequency ranges are on, test them against nyquist frequency
         need_nyquist_check = any(
-            (f in USE_FREQ_RANGES for f in self.settings.features.get_enabled())
+            (f in USE_FREQ_RANGES for f in self.data_processor.settings.features.get_enabled())
         )
 
         if need_nyquist_check:
             assert all(
-                fb.frequency_high_hz < self.sfreq / 2
-                for fb in self.settings.frequency_ranges_hz.values()
+                fb.frequency_high_hz < self.data_processor.sfreq_raw / 2
+                for fb in self.data_processor.settings.frequency_ranges_hz.values()
             ), (
                 "If a feature that uses frequency ranges is selected, "
                 "the frequency band ranges need to be smaller than the nyquist frequency.\n"
-                f"Got sfreq = {self.sfreq} and fband ranges:\n {self.settings.frequency_ranges_hz}"
+                f"Got sfreq = {self.data_processor.sfreq_raw} and fband ranges:\n {self.data_processor.settings.frequency_ranges_hz}"
             )
 
         self.stream_handling_queue = stream_handling_queue
-        # self.feature_queue = feature_queue
-        self.save_csv = save_csv
-        self.save_interval = save_interval
-        self.return_df = return_df
-
-        # Generate output dirs
-        self.out_dir_root = Path.cwd() if not out_dir else Path(out_dir)
-        self.out_dir = self.out_dir_root / self.experiment_name
-        # TONI: Need better default experiment name
-
-        self.out_dir.mkdir(parents=True, exist_ok=True)
-
-        # Open database connection
-        # TONI: we should give the user control over the save format
-        from py_neuromodulation.utils.database import NMDatabase
-
-        self.db = NMDatabase(self.experiment_name, out_dir)  # Create output database
-
-        self.batch_count: int = 0  # Keep track of the number of batches processed
-
-        # Reinitialize the data processor in case the nm_channels or nm_settings changed between runs of the same Stream
-        self._initialize_data_processor()
-
-        logger.log_to_file(out_dir)
-
-        # # Initialize mp.Pool for multiprocessing
-        # self.pool = mp.Pool(processes=self.settings.n_jobs)
-        # # Set up shared memory for multiprocessing
-        # self.shared_memory = mp.Array(ctypes.c_double, self.settings.n_jobs * self.settings.n_jobs)
-        # # Set up multiprocessing semaphores
-        # self.semaphore = mp.Semaphore(self.settings.n_jobs)
+        self.is_running = False
+        self.is_lslstream = not isinstance(data_generator, RawDataGenerator)
 
         prev_batch_end = 0
-        for timestamps, data_batch in self.generator:
+        for timestamps, data_batch in data_generator:
             self.is_running = True
             if self.stream_handling_queue is not None:
+                await asyncio.sleep(0.001)
                 if not self.stream_handling_queue.empty():
-                    value = self.stream_handling_queue.get()
-                    if value == "stop":
+                    stop_signal = await asyncio.wait_for(self.stream_handling_queue.get(), timeout=0.01)
+                    if stop_signal == "stop":
                         break
             if data_batch is None:
                 break
 
-            feature_dict = self.data_processor.process(data_batch)
+            feature_dict = data_processor.process(data_batch)
 
             this_batch_end = timestamps[-1]
             batch_length = this_batch_end - prev_batch_end
@@ -338,7 +93,7 @@ def run(
             feature_dict["time"] = (
                 batch_length
-                if self.is_stream_lsl
+                if self.is_lslstream
                 else np.ceil(this_batch_end * 1000 + 1)
             )
 
@@ -347,39 +102,23 @@ def run(
             if self.verbose:
                 logger.info("Time: %.2f", feature_dict["time"] / 1000)
 
-            self._add_target(feature_dict, data_batch)
+            feature_dict = data_generator.add_target(feature_dict, data_batch)
 
-            # We should ensure that feature output is always either float64 or None and remove this
             with suppress(TypeError):  # Need this because some features output None
                 for key, value in feature_dict.items():
                     feature_dict[key] = np.float64(value)
 
-            self.db.insert_data(feature_dict)
-
-            # if self.feature_queue is not None:
-            #     self.feature_queue.put(feature_dict)
-
-            # if websocket_features is not None:
-            #     logger.info("Sending message to Websocket")
-            #     await websocket_featues.send_message(feature_dict)
-
-            self.batch_count += 1
-            if self.batch_count % self.save_interval == 0:
-                self.db.commit()
+            data_writer.write_data(feature_dict)
 
-        self.db.commit()  # Save last batches
+            if websocket_featues is not None:
+                await websocket_featues.send_cbor(feature_dict)
 
-        # If save_csv is False, still save the first row to get the column names
-        feature_df: "pd.DataFrame" = (
-            self.db.fetch_all() if (self.save_csv or self.return_df) else self.db.head()
-        )
-
-        self.db.close()  # Close the database connection
-
-        self._save_after_stream(feature_arr=feature_df)
+        feature_df = data_writer.get_features()
+        data_writer.save_csv_features(feature_df)
+        self._save_sidecars_after_stream(data_writer.out_dir, data_writer.experiment_name)
         self.is_running = False
 
-        return feature_df  # TONI: Not sure if this makes sense anymore
+        return feature_df
 
     def plot_raw_signal(
         self,
@@ -445,83 +184,25 @@ def plot_raw_signal(
         if plot_psd:
             raw.compute_psd().plot()
 
-    def _handle_data(self, data: "np.ndarray | pd.DataFrame") -> np.ndarray:
-        """_summary_
-
-        Args:
-            data (np.ndarray | pd.DataFrame):
-        Raises:
-            ValueError: _description_
-            ValueError: _description_
-
-        Returns:
-            np.ndarray: _description_
-        """
-        names_expected = self.channels["name"].to_list()
-
-        if isinstance(data, np.ndarray):
-            if not len(names_expected) == data.shape[0]:
-                raise ValueError(
-                    "If data is passed as an array, the first dimension must"
-                    " match the number of channel names in `channels`.\n"
-                    f" Number of data channels (data.shape[0]): {data.shape[0]}\n"
-                    f' Length of channels["name"]: {len(names_expected)}.'
-                )
-            return data
-
-        names_data = data.columns.to_list()
-
-        if not (
-            len(names_expected) == len(names_data)
-            and sorted(names_expected) == sorted(names_data)
-        ):
-            raise ValueError(
-                "If data is passed as a DataFrame, the"
-                "column names must match the channel names in `channels`.\n"
-                f"Input dataframe column names: {names_data}\n"
-                f'Expected (from channels["name"]): : {names_expected}.'
-            )
-        return data.to_numpy().transpose()
-
-    def _initialize_data_processor(self) -> None:
-        self.data_processor = DataProcessor(
-            sfreq=self.sfreq,
-            settings=self.settings,
-            channels=self.channels,
-            path_grids=self.path_grids,
-            coord_names=self.coord_names,
-            coord_list=self.coord_list,
-            line_noise=self.line_noise,
-            verbose=self.verbose,
-        )
-
-    def _save_after_stream(
-        self,
-        feature_arr: "pd.DataFrame | None" = None,
-    ) -> None:
-        """Save features, settings, nm_channels and sidecar after run"""
-        self._save_sidecar()
-        if feature_arr is not None:
-            self._save_features(feature_arr)
-        self._save_settings()
-        self._save_channels()
-
-    def _save_features(
+    def _save_sidecars_after_stream(
         self,
-        feature_arr: "pd.DataFrame",
+        out_dir: _PathLike,
+        experiment_name: str = "experiment"
     ) -> None:
-        save_features(feature_arr, self.out_dir, self.experiment_name)
+        """Save settings, nm_channels and sidecar after run"""
+        self._save_sidecar(out_dir, experiment_name)
+        self._save_settings(out_dir, experiment_name)
+        self._save_channels(out_dir, experiment_name)
 
-    def _save_channels(self) -> None:
-        self.data_processor.save_channels(self.out_dir, self.experiment_name)
+    def _save_channels(self, out_dir, experiment_name) -> None:
+        self.data_processor.save_channels(out_dir, experiment_name)
 
-    def _save_settings(self) -> None:
-        self.data_processor.save_settings(self.out_dir, self.experiment_name)
+    def _save_settings(self, out_dir, experiment_name) -> None:
+        self.data_processor.save_settings(out_dir, experiment_name)
 
-    def _save_sidecar(self) -> None:
+    def _save_sidecar(self, out_dir, experiment_name) -> None:
         """Save sidecar incduing fs, coords, sess_right to out_path_root and subfolder 'folder_name'"""
-        additional_args = {"sess_right": self.sess_right}
         self.data_processor.save_sidecar(
-            self.out_dir, self.experiment_name, additional_args
+            out_dir, experiment_name
         )
diff --git a/py_neuromodulation/utils/data_writer.py b/py_neuromodulation/utils/data_writer.py
new file mode 100644
index 00000000..7fdb3d3c
--- /dev/null
+++ b/py_neuromodulation/utils/data_writer.py
@@ -0,0 +1,51 @@
+from py_neuromodulation.utils.types import _PathLike
+from py_neuromodulation.utils import logger, io
+from pathlib import Path
+
+class DataWriter:
+
+    def __init__(self, out_dir: _PathLike = "", save_csv: bool = False,
+                 save_interval: int = 10, experiment_name: str = "experiment"):
+
+        self.batch_count: int = 0
+        self.save_interval: int = save_interval
+        self.save_csv: bool = save_csv
+        self.out_dir: _PathLike = out_dir
+        self.experiment_name: str = experiment_name
+
+        self.out_dir_root = Path.cwd() if not out_dir else Path(out_dir)
+        self.out_dir = self.out_dir_root / self.experiment_name
+        self.out_dir.mkdir(parents=True, exist_ok=True)
+
+        from py_neuromodulation.utils.database import NMDatabase
+        self.db = NMDatabase(self.experiment_name, out_dir)
+
+        logger.log_to_file(out_dir)
+
+
+    def write_data(self, feature_dict):
+
+        self.db.insert_data(feature_dict)
+        self.batch_count += 1
+        if self.batch_count % self.save_interval == 0:
+            self.db.commit()
+
+    def get_features(self, return_df: bool = False):
+
+        self.db.commit()  # Save last batches
+
+        # If save_csv is False, still save the first row to get the column names
+        feature_df = (
+            self.db.fetch_all() if (self.save_csv or return_df) else self.db.head()
+        )
+
+        self.db.close()
+        return feature_df
+
+    def save_csv_features(
+        self,
+        df_features: "pd.DataFrame"
+    ) -> None:
+        filename = f"{self.experiment_name}_FEATURES.csv" if self.experiment_name else "_FEATURES.csv"
+        io.write_csv(df_features, self.out_dir / filename)
+        logger.info(f"{filename} saved to {str(self.out_dir)}")
diff --git a/py_neuromodulation/utils/io.py b/py_neuromodulation/utils/io.py
index 35289e59..68c2c792 100644
--- a/py_neuromodulation/utils/io.py
+++ b/py_neuromodulation/utils/io.py
@@ -258,17 +258,6 @@ def save_channels(
     logger.info(f"{filename} saved to {out_dir}")
 
 
-def save_features(
-    df_features: "pd.DataFrame",
-    out_dir: _PathLike = "",
-    prefix: str = "",
-) -> None:
-    out_dir = Path.cwd() if not out_dir else Path(out_dir)
-    filename = f"{prefix}_FEATURES.csv" if prefix else "_FEATURES.csv"
-    write_csv(df_features, out_dir / filename)
-    logger.info(f"{filename} saved to {str(out_dir)}")
-
-
 def save_sidecar(
     sidecar: dict,
     out_dir: _PathLike = "",
diff --git a/tests/test_lsl_stream.py b/tests/test_lsl_stream.py
index b40859ce..068b9ee9 100644
--- a/tests/test_lsl_stream.py
+++ b/tests/test_lsl_stream.py
@@ -8,12 +8,12 @@
 
 @pytest.mark.parametrize("setup_lsl_player", ["search"], indirect=True)
 def test_lsl_stream_search(setup_lsl_player):
-    from py_neuromodulation.stream import mnelsl_stream
+    from py_neuromodulation.stream import mnelsl_generator
 
     """ Test if the lsl stream search can find any streams after starting a player."""
    player = setup_lsl_player
    player.start_player()
-    streams = mnelsl_stream.resolve_streams()
+    streams = mnelsl_generator.resolve_streams()
 
    assert len(streams) != 0, "No streams found in search"
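
For the LSL path, the same run loop is driven by MNELSLGenerator and stopped
through an asyncio.Queue. A sketch under stated assumptions: a stream named
"example_stream" is already being served (e.g. by LSLOfflinePlayer), the
DataProcessor arguments follow the deleted _initialize_data_processor, and
NMSettings.load(None) yields the default settings:

    import asyncio

    from py_neuromodulation.stream import MNELSLGenerator, NMSettings, Stream
    from py_neuromodulation.stream.data_processor import DataProcessor
    from py_neuromodulation.utils.data_writer import DataWriter

    async def main() -> None:
        settings = NMSettings.load(None)  # assumed: default settings.yaml
        generator = MNELSLGenerator(
            segment_length_features_ms=settings.segment_length_features_ms,
            sampling_rate_features_hz=settings.sampling_rate_features_hz,
            stream_name="example_stream",  # assumed to be served already
        )
        # channel layout and sfreq are read from the LSL stream info
        processor = DataProcessor(
            sfreq=generator.sfreq,
            settings=settings,
            channels=generator.channels,
            path_grids=None,
            coord_names=None,
            coord_list=None,
            line_noise=50,
            verbose=True,
        )
        writer = DataWriter(out_dir="out", save_csv=True, experiment_name="lsl_run")
        stop_queue: asyncio.Queue = asyncio.Queue()

        run_task = asyncio.create_task(
            Stream(verbose=True).run(
                data_processor=processor,
                data_generator=generator,
                data_writer=writer,
                stream_handling_queue=stop_queue,
            )
        )
        await asyncio.sleep(10)        # stream for ~10 s
        await stop_queue.put("stop")   # run() breaks out when it reads "stop"
        await run_task

    if __name__ == "__main__":
        asyncio.run(main())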