From 7249bfbff35946633667f0b182c72f32e9b64e4e Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 15 Nov 2023 11:20:51 +0100 Subject: [PATCH] add raw data and psd plotter function for OfflineStream --- py_neuromodulation/nm_stream_offline.py | 76 ++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/py_neuromodulation/nm_stream_offline.py b/py_neuromodulation/nm_stream_offline.py index 3bc345d6..6172411e 100644 --- a/py_neuromodulation/nm_stream_offline.py +++ b/py_neuromodulation/nm_stream_offline.py @@ -5,13 +5,20 @@ import numpy as np import pandas as pd -from py_neuromodulation import nm_generator, nm_IO, nm_stream_abc, nm_define_nmchannels +import mne + +from py_neuromodulation import ( + nm_generator, + nm_IO, + nm_stream_abc, + nm_define_nmchannels, +) _PathLike = str | os.PathLike class _OfflineStream(nm_stream_abc.PNStream): - """Offline stream base class. + """Offline stream base class. This class can be inhereted for different types of offline streams, e.g. epoch-based or continuous. Parameters @@ -91,7 +98,7 @@ def _run_offline( sample_add = self.sfreq / self.run_analysis.sfreq_features offset_time = self.settings["segment_length_features_ms"] - #offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int) + # offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int) offset_start = offset_time / 1000 * self.sfreq cnt_samples = offset_start @@ -100,7 +107,9 @@ def _run_offline( data_batch = next(generator, None) if data_batch is None: break - feature_series = self.run_analysis.process(data_batch.astype(np.float64)) + feature_series = self.run_analysis.process( + data_batch.astype(np.float64) + ) feature_series = self._add_timestamp(feature_series, cnt_samples) features.append(feature_series) @@ -116,6 +125,55 @@ def _run_offline( return feature_df + def plot_raw_signal( + self, + sfreq: float, + data: np.array = None, + plot_time: bool = True, + plot_psd: bool = True, + ) -> None: + """Use MNE-RawArray Plot to investigate PSD or raw_signal plot. + + Parameters + ---------- + sfreq : float + sampling frequency [Hz] + data : np.array, optional + data (n_channels, n_times), by default None + plot_time : bool, optional + mne.io.RawArray.plot(), by default True + plot_psd : bool, optional + mne.io.RawArray.plot(), by default True + + Raises + ------ + ValueError + raise Exception when no data is passed + """ + if self.data is None and data is None: + raise ValueError("No data passed to plot_raw_signal function.") + + if data is None and self.data is not None: + data = self.data + + if self.nm_channels is not None: + ch_names = self.nm_channels["name"].to_list() + ch_types = self.nm_channels["type"].to_list() + else: + ch_names = [f"ch_{i}" for i in range(data.shape[0])] + ch_types = ["ecog" for i in range(data.shape[0])] + + # create mne.RawArray + info = mne.create_info( + ch_names=ch_names, sfreq=self.sfreq, ch_types=ch_types + ) + raw = mne.io.RawArray(data, info) + self.raw = raw + if plot_time: + raw.plot() + if plot_psd: + raw.plot_psd() + class Stream(_OfflineStream): def __init__( @@ -129,7 +187,7 @@ def __init__( path_grids: _PathLike | None = None, coord_names: list | None = None, coord_list: list | None = None, - verbose: bool = True, + verbose: bool = True, ) -> None: """Stream initialization @@ -158,7 +216,9 @@ def __init__( """ if nm_channels is None and data is not None: - nm_channels = nm_define_nmchannels.get_default_channels_from_data(data) + nm_channels = nm_define_nmchannels.get_default_channels_from_data( + data + ) if nm_channels is None and data is None: raise ValueError( @@ -210,8 +270,6 @@ def run( elif self.data is not None: data = self._handle_data(self.data) elif self.data is None and data is None: - raise ValueError( - "No data passed to run function." - ) + raise ValueError("No data passed to run function.") return self._run_offline(data, out_path_root, folder_name)