Skip to content

Commit

Permalink
refactor stream
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Oct 6, 2024
1 parent 2ab0bf8 commit 6b85c3e
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 466 deletions.
1 change: 1 addition & 0 deletions py_neuromodulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

from .utils import types
from .utils import io
from .utils import data_writer

from . import stream
from . import analysis
Expand Down
4 changes: 2 additions & 2 deletions py_neuromodulation/stream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .generator import RawDataGenerator
from .rawdata_generator import RawDataGenerator
from .mnelsl_player import LSLOfflinePlayer
from .mnelsl_stream import LSLStream
from .mnelsl_generator import MNELSLGenerator
from .stream import Stream
from .settings import NMSettings
12 changes: 12 additions & 0 deletions py_neuromodulation/stream/data_generator_abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from typing import Tuple


class DataGeneratorABC(ABC):
    """Abstract base class for batch data generators.

    Concrete generators (e.g. raw-data- or LSL-stream-backed) implement the
    iterator protocol and yield ``(timestamps, data)`` batch pairs, where
    ``timestamps`` is a 1D array of sample times and ``data`` is a
    ``(channels, samples)`` array.
    """

    def __init__(self) -> None:
        # BUGFIX: the original annotated __init__ with
        # ``Tuple[float, "pd.DataFrame"]``; __init__ always returns None.
        pass

    @abstractmethod
    def __next__(self) -> Tuple["np.ndarray", "np.ndarray"]:
        """Return the next (timestamps, data) batch.

        Raises
        ------
        StopIteration
            When the underlying data source is exhausted.
        """
11 changes: 3 additions & 8 deletions py_neuromodulation/stream/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
self.sfreq_raw: float = sfreq // 1
self.line_noise: float | None = line_noise
self.path_grids: _PathLike | None = path_grids
if path_grids is None:
import py_neuromodulation as nm
path_grids = nm.PYNM_DIR #NOTE: could be optimized
self.verbose: bool = verbose

self.features_previous = None
Expand Down Expand Up @@ -315,11 +318,3 @@ def save_settings(self, out_dir: _PathLike, prefix: str = "") -> None:

def save_channels(self, out_dir: _PathLike, prefix: str) -> None:
io.save_channels(self.channels, out_dir, prefix)

def save_features(
self,
feature_arr: "pd.DataFrame",
out_dir: _PathLike = "",
prefix: str = "",
) -> None:
io.save_features(feature_arr, out_dir, prefix)
53 changes: 0 additions & 53 deletions py_neuromodulation/stream/generator.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from collections.abc import Iterator
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
import numpy as np
from py_neuromodulation.utils import logger
from mne_lsl.lsl import resolve_streams
import os
from .data_generator_abc import DataGeneratorABC

if TYPE_CHECKING:
from py_neuromodulation import NMSettings


class LSLStream:
class MNELSLGenerator(DataGeneratorABC):
"""
Class is used to create and connect to a LSL stream and pull data from it.
"""

def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> None:
def __init__(self,
segment_length_features_ms: float,
sampling_rate_features_hz: float,
stream_name: str | None = "example_stream",
) -> None:
"""
Initialize the LSL stream.
Expand All @@ -36,7 +41,6 @@ def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> No
self.stream: StreamLSL
# self.keyboard_interrupt = False

self.settings = settings
self._n_seconds_wait_before_disconnect = 3
try:
if stream_name is None:
Expand All @@ -55,18 +59,30 @@ def __init__(self, settings: "NMSettings", stream_name: str | None = None) -> No
else:
self.sinfo = self.stream.sinfo

self.winsize = settings.segment_length_features_ms / self.stream.sinfo.sfreq
self.sampling_interval = 1 / self.settings.sampling_rate_features_hz

# If not running the generator when the escape key is pressed.
self.headless: bool = not os.environ.get("DISPLAY")
# if not self.headless:
# from py_neuromodulation.utils.keyboard import KeyboardListener

# self.listener = KeyboardListener(("esc", self.set_keyboard_interrupt))
# self.listener.start()

def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
self.winsize = segment_length_features_ms / self.stream.sinfo.sfreq
self.sampling_interval = 1 / sampling_rate_features_hz
self.channels = self.get_LSL_channels()
self.sfreq = self.stream.sinfo.sfreq

def get_LSL_channels(self) -> "pd.DataFrame":

from py_neuromodulation.utils import create_channels
ch_names = self.sinfo.get_channel_names() or [
"ch" + str(i) for i in range(self.sinfo.n_channels)
]
ch_types = self.sinfo.get_channel_types() or [
"eeg" for i in range(self.sinfo.n_channels)
]
return create_channels(
ch_names=ch_names,
ch_types=ch_types,
used_types=["eeg", "ecog", "dbs", "seeg"],
)

def __iter__(self):
return self

def __next__(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
self.last_time = time.time()
check_data = None
data = None
Expand Down Expand Up @@ -112,11 +128,3 @@ def get_next_batch(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
yield timestamp, data

logger.info(f"Stream time: {timestamp[-1] - stream_start_time}")

# if not self.headless and self.keyboard_interrupt:
# logger.info("Keyboard interrupt")
# self.listener.stop()
# self.stream.disconnect()

# def set_keyboard_interrupt(self):
# self.keyboard_interrupt = True
159 changes: 159 additions & 0 deletions py_neuromodulation/stream/rawdata_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import os
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd

from py_neuromodulation.utils import logger
from py_neuromodulation.utils import create_channels
from py_neuromodulation.utils.io import MNE_FORMATS, read_mne_data, load_channels
from py_neuromodulation.utils.types import _PathLike
from .data_generator_abc import DataGeneratorABC

class RawDataGenerator(DataGeneratorABC):
    """
    This generator function mimics online data acquisition.
    The data are iteratively sampled with settings.sampling_rate_features_hz
    """

    def __init__(
        self,
        data: "np.ndarray | pd.DataFrame | _PathLike | None",
        sampling_rate_features_hz: float,
        segment_length_features_ms: float,
        channels: "pd.DataFrame | None",
        sfreq: "float | None",
    ) -> None:
        """
        Arguments
        ---------
        data (np array): shape (channels, time), DataFrame with one column
            per channel, or path to an MNE-supported file
        sampling_rate_features_hz (float): rate at which batches are sampled
        segment_length_features_ms (float): width of the moving window in ms
        channels (pd.DataFrame | None): channel description; required when
            data is an array/DataFrame, built from the file otherwise
        sfreq (float): sampling frequency of the data; read from the file
            when data is a path
        Raises
        ------
        ValueError: on unsupported data type/file format, missing sfreq,
            or missing channels for in-memory data.
        """
        self.channels = channels
        self.sfreq = sfreq
        self.batch_counter: int = 0  # counter for the batches
        self.target_idx_initialized: bool = False

        if isinstance(data, (np.ndarray, pd.DataFrame)):
            logger.info(f"Loading data from {type(data).__name__}")
            if self.channels is None:
                raise ValueError(
                    "Channels must be provided when data is passed as an "
                    "array or DataFrame"
                )
            # BUGFIX: route in-memory data through _handle_data so that a
            # DataFrame is validated and transposed to (channels, time);
            # the original stored it as-is, which broke __next__'s slicing.
            self.data = self._handle_data(data)
        elif isinstance(data, (str, os.PathLike)):
            # BUGFIX: the original tested `isinstance(self.data, _PathLike)`,
            # but `self.data` is not yet assigned in this branch.
            logger.info("Loading data from file")
            filepath = Path(data)
            ext = filepath.suffix

            if ext in MNE_FORMATS:
                data, sfreq, ch_names, ch_types, bads = read_mne_data(filepath)
            else:
                raise ValueError(f"Unsupported file format: {ext}")
            self.channels = create_channels(
                ch_names=ch_names,
                ch_types=ch_types,
                used_types=["eeg", "ecog", "dbs", "seeg"],
                bads=bads,
            )

            if sfreq is None:
                raise ValueError(
                    "Sampling frequency not specified in file, please specify sfreq as a parameters"
                )
            self.sfreq = sfreq
            self.data = self._handle_data(data)
        else:
            raise ValueError(
                "Data must be either a numpy array, a pandas DataFrame, or a path to an MNE supported file"
            )

        if self.sfreq is None:
            # BUGFIX: the original computed window sizes with a possibly-None
            # sfreq, which raised an opaque TypeError below.
            raise ValueError(
                "sfreq must be provided when data is passed as an array or DataFrame"
            )
        # Width, in data points, of the moving window used to calculate features
        self.segment_length = segment_length_features_ms / 1000 * self.sfreq
        # Ratio of the sampling frequency of the input data to the sampling frequency
        self.stride = self.sfreq / sampling_rate_features_hz

        # BUGFIX: the original unconditionally re-assigned here
        # (`load_channels(channels) if channels is not None else None`),
        # discarding the channels built from the file when channels was None.
        if channels is not None:
            self.channels = load_channels(channels)

    def _handle_data(self, data: "np.ndarray | pd.DataFrame") -> np.ndarray:
        """Validate in-memory data against `self.channels` and return it as a
        (channels, time) ndarray.

        Args:
            data (np.ndarray | pd.DataFrame): array of shape (channels, time)
                or DataFrame with one column per channel.
        Raises:
            ValueError: if the channel count (array) or column names
                (DataFrame) do not match `self.channels["name"]`.
        Returns:
            np.ndarray: data of shape (channels, time).
        """
        names_expected = self.channels["name"].to_list()

        if isinstance(data, np.ndarray):
            if not len(names_expected) == data.shape[0]:
                raise ValueError(
                    "If data is passed as an array, the first dimension must"
                    " match the number of channel names in `channels`.\n"
                    f" Number of data channels (data.shape[0]): {data.shape[0]}\n"
                    f' Length of channels["name"]: {len(names_expected)}.'
                )
            return data

        names_data = data.columns.to_list()

        if not (
            len(names_expected) == len(names_data)
            and sorted(names_expected) == sorted(names_data)
        ):
            raise ValueError(
                "If data is passed as a DataFrame, the"
                "column names must match the channel names in `channels`.\n"
                f"Input dataframe column names: {names_data}\n"
                f'Expected (from channels["name"]): : {names_expected}.'
            )
        return data.to_numpy().transpose()

    def add_target(
        self, feature_dict: "pd.DataFrame", data_batch: np.ndarray
    ) -> "pd.DataFrame":
        """Add target channels to feature series.

        Parameters
        ----------
        feature_dict : pd.DataFrame
            features computed for the current batch
        data_batch : np.ndarray
            raw data batch of shape (channels, samples); the last sample of
            each target channel is used as the target value
        Returns
        -------
        pd.DataFrame
            feature dict with target channels added
        """
        if not (isinstance(self.channels, pd.DataFrame)):
            raise ValueError("Channels must be a pandas DataFrame")

        if self.channels["target"].sum() > 0:
            # Cache target indexes/names on first use; the channel layout
            # does not change between batches.
            if not self.target_idx_initialized:
                self.target_indexes = self.channels[self.channels["target"] == 1].index
                self.target_names = self.channels.loc[
                    self.target_indexes, "name"
                ].to_list()
                self.target_idx_initialized = True

            for target_idx, target_name in zip(self.target_indexes, self.target_names):
                feature_dict[target_name] = data_batch[target_idx, -1]

        return feature_dict

    def __iter__(self):
        return self

    def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return (timestamps, batch) for the next moving window.

        Raises StopIteration when the window would run past the data.
        """
        start = self.stride * self.batch_counter
        end = start + self.segment_length

        self.batch_counter += 1

        start_idx = int(start)
        end_idx = int(end)

        if end_idx > self.data.shape[1]:
            raise StopIteration

        return np.arange(start, end) / self.sfreq, self.data[:, start_idx:end_idx]
Loading

0 comments on commit 6b85c3e

Please sign in to comment.