Skip to content

Commit

Permalink
add raw data and psd plotter function for OfflineStream
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Nov 15, 2023
1 parent df7558a commit 7249bfb
Showing 1 changed file with 67 additions and 9 deletions.
76 changes: 67 additions & 9 deletions py_neuromodulation/nm_stream_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
import numpy as np
import pandas as pd

from py_neuromodulation import nm_generator, nm_IO, nm_stream_abc, nm_define_nmchannels
import mne

from py_neuromodulation import (
nm_generator,
nm_IO,
nm_stream_abc,
nm_define_nmchannels,
)

_PathLike = str | os.PathLike


class _OfflineStream(nm_stream_abc.PNStream):
"""Offline stream base class.
"""Offline stream base class.
This class can be inhereted for different types of offline streams, e.g. epoch-based or continuous.
Parameters
Expand Down Expand Up @@ -91,7 +98,7 @@ def _run_offline(
sample_add = self.sfreq / self.run_analysis.sfreq_features

offset_time = self.settings["segment_length_features_ms"]
#offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int)
# offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int)
offset_start = offset_time / 1000 * self.sfreq

cnt_samples = offset_start
Expand All @@ -100,7 +107,9 @@ def _run_offline(
data_batch = next(generator, None)
if data_batch is None:
break
feature_series = self.run_analysis.process(data_batch.astype(np.float64))
feature_series = self.run_analysis.process(
data_batch.astype(np.float64)
)
feature_series = self._add_timestamp(feature_series, cnt_samples)
features.append(feature_series)

Expand All @@ -116,6 +125,55 @@ def _run_offline(

return feature_df

def plot_raw_signal(
self,
sfreq: float,
data: np.array = None,
plot_time: bool = True,
plot_psd: bool = True,
) -> None:
"""Use MNE-RawArray Plot to investigate PSD or raw_signal plot.
Parameters
----------
sfreq : float
sampling frequency [Hz]
data : np.array, optional
data (n_channels, n_times), by default None
plot_time : bool, optional
mne.io.RawArray.plot(), by default True
plot_psd : bool, optional
mne.io.RawArray.plot(), by default True
Raises
------
ValueError
raise Exception when no data is passed
"""
if self.data is None and data is None:
raise ValueError("No data passed to plot_raw_signal function.")

if data is None and self.data is not None:
data = self.data

if self.nm_channels is not None:
ch_names = self.nm_channels["name"].to_list()
ch_types = self.nm_channels["type"].to_list()
else:
ch_names = [f"ch_{i}" for i in range(data.shape[0])]
ch_types = ["ecog" for i in range(data.shape[0])]

# create mne.RawArray
info = mne.create_info(
ch_names=ch_names, sfreq=self.sfreq, ch_types=ch_types
)
raw = mne.io.RawArray(data, info)
self.raw = raw
if plot_time:
raw.plot()
if plot_psd:
raw.plot_psd()


class Stream(_OfflineStream):
def __init__(
Expand All @@ -129,7 +187,7 @@ def __init__(
path_grids: _PathLike | None = None,
coord_names: list | None = None,
coord_list: list | None = None,
verbose: bool = True,
verbose: bool = True,
) -> None:
"""Stream initialization
Expand Down Expand Up @@ -158,7 +216,9 @@ def __init__(
"""

if nm_channels is None and data is not None:
nm_channels = nm_define_nmchannels.get_default_channels_from_data(data)
nm_channels = nm_define_nmchannels.get_default_channels_from_data(
data
)

if nm_channels is None and data is None:
raise ValueError(
Expand Down Expand Up @@ -210,8 +270,6 @@ def run(
elif self.data is not None:
data = self._handle_data(self.data)
elif self.data is None and data is None:
raise ValueError(
"No data passed to run function."
)
raise ValueError("No data passed to run function.")

return self._run_offline(data, out_path_root, folder_name)

0 comments on commit 7249bfb

Please sign in to comment.