Merge pull request #290 from neuromodulation/parallel_processing_pr
Parallel processing pr
timonmerk authored Jan 27, 2024
2 parents 702c3d3 + f349ec1 commit dcb981b
Showing 4 changed files with 67 additions and 23 deletions.
3 changes: 2 additions & 1 deletion py_neuromodulation/nm_coherence.py
@@ -111,14 +111,15 @@ def get_coh(self, features_compute, x, y):

class NM_Coherence(nm_features_abc.Feature):

coherence_objects: Iterable[CoherenceObject] = []


def __init__(
self, settings: dict, ch_names: Iterable[str], sfreq: float
) -> None:
self.s = settings
self.sfreq = sfreq
self.ch_names = ch_names
self.coherence_objects: Iterable[CoherenceObject] = []

for idx_coh in range(len(self.s["coherence"]["channels"])):
fband_names = self.s["coherence"]["frequency_bands"]
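
The change above moves coherence_objects from a class-level attribute into __init__. A mutable default declared on the class is shared by every instance, so coherence objects from one stream would leak into another once several streams exist. A minimal standalone sketch of the pitfall (illustrative names, not repository code):

class Shared:
    objects: list = []  # class attribute: one list shared by all instances

a, b = Shared(), Shared()
a.objects.append("coh_ch1")
print(b.objects)  # ['coh_ch1'] -- b sees a's state

class PerInstance:
    def __init__(self) -> None:
        self.objects: list = []  # a fresh list per instance, as in this commit

a, b = PerInstance(), PerInstance()
a.objects.append("coh_ch1")
print(b.objects)  # []
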
2 changes: 1 addition & 1 deletion py_neuromodulation/nm_mne_connectivity.py
@@ -59,7 +59,7 @@ def get_epoched_data(
if epochs.events.shape[0] < 2:
raise Exception(
f"A minimum of 2 epochs is required for mne_connectivity,"
f" got only {epochs.events.shape[0]}. Increase settings['segment_length']"
f" got only {epochs.events.shape[0]}. Increase settings['segment_length_features_ms']"
)
return epochs
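
The corrected message points at the settings key that actually exists, segment_length_features_ms. A rough illustration of how that key bounds the number of epochs a feature segment can be split into (the epoch duration and all numbers below are assumptions):

sfreq = 1000                       # Hz (assumed)
segment_length_features_ms = 1500  # the settings key named in the new message
epoch_length_ms = 1000             # epoch duration for mne_connectivity (assumed)

segment_samples = int(sfreq * segment_length_features_ms / 1000)
n_epochs = segment_samples // int(sfreq * epoch_length_ms / 1000)
print(n_epochs)  # 1 -> raises the "minimum of 2 epochs" exception above;
                 # increasing segment_length_features_ms raises n_epochs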

84 changes: 63 additions & 21 deletions py_neuromodulation/nm_stream_offline.py
@@ -1,10 +1,11 @@
"""Module for offline data streams."""
import math
import os

from joblib import Parallel, delayed
import numpy as np
import pandas as pd

from itertools import count

import mne

from py_neuromodulation import (
@@ -67,7 +68,6 @@ def _add_timestamp(
Due to normalization, run_analysis needs to keep track of the counted
samples. These are accessed here for time conversion.
"""
timestamp = cnt_samples * 1000 / self.sfreq
feature_series["time"] = cnt_samples * 1000 / self.sfreq

if self.verbose:
Expand Down Expand Up @@ -101,45 +101,82 @@ def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray:
f"Data columns: {names_data}, nm_channels.name: {names_data}."
)
return data.to_numpy()

def _check_settings_for_parallel(self):
"""Check specified settings and raise an error if parallel processing is not possible.

Raises:
ValueError: if the given settings are incompatible with parallel processing
"""

if "raw_normalization" in self.settings["preprocessing"]:
raise ValueError(
"Parallel processing is not possible with raw_normalization normalization."
)
if self.settings["postprocessing"]["feature_normalization"] is True:
raise ValueError(
"Parallel processing is not possible with feature normalization."
)
if self.settings["features"]["bursts"] is True:
raise ValueError(
"Parallel processing is not possible with burst estimation."
)


def _process_batch(self, data_batch, cnt_samples):
feature_series = self.run_analysis.process(
data_batch.astype(np.float64)
)
feature_series = self._add_timestamp(feature_series, cnt_samples)
return feature_series

def _run_offline(
self,
data: np.ndarray,
out_path_root: _PathLike | None = None,
folder_name: str = "sub",
parallel: bool = False,
n_jobs: int = -2,
) -> pd.DataFrame:
generator = nm_generator.raw_data_generator(
data=data,
settings=self.settings,
sfreq=self.sfreq,
)
features = []

sample_add = self.sfreq / self.run_analysis.sfreq_features

offset_time = self.settings["segment_length_features_ms"]
# offset_start = np.ceil(offset_time / 1000 * self.sfreq).astype(int)
offset_start = offset_time / 1000 * self.sfreq

cnt_samples = offset_start

while True:
data_batch = next(generator, None)
if data_batch is None:
break
feature_series = self.run_analysis.process(
data_batch.astype(np.float64)
if parallel:
l_features = Parallel(n_jobs=n_jobs, verbose=10)(
delayed(self._process_batch)(data_batch, cnt_samples)
for data_batch, cnt_samples in zip(
generator, count(offset_start, sample_add)
)
)
feature_series = self._add_timestamp(feature_series, cnt_samples)
feature_series = self._add_target(feature_series, data_batch)

features.append(feature_series)

if self.model is not None:
prediction = self.model.predict(feature_series)
else:
l_features = []
cnt_samples = offset_start
while True:
data_batch = next(generator, None)
if data_batch is None:
break
feature_series = self.run_analysis.process(
data_batch.astype(np.float64)
)
feature_series = self._add_timestamp(
feature_series, cnt_samples
)
l_features.append(feature_series)

cnt_samples += sample_add
cnt_samples += sample_add
feature_df = pd.DataFrame(l_features)

feature_df = pd.DataFrame(features)
feature_df = self._add_target(feature_series=feature_df, data=data)

self.save_after_stream(out_path_root, folder_name, feature_df)
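
In the parallel branch of _run_offline, itertools.count(offset_start, sample_add) pairs every batch from the generator with its running sample index, and joblib fans the _process_batch calls out across workers (n_jobs=-2 is joblib's convention for "all cores but one"). A minimal standalone sketch of the same dispatch pattern, with made-up numbers:

from itertools import count
from joblib import Parallel, delayed

sfreq = 1000.0         # Hz (assumed)
offset_start = 1000.0  # samples spanned by the first feature window
sample_add = 100.0     # samples per feature step, i.e. sfreq / sfreq_features

def process(batch, cnt_samples):
    # stand-in for _process_batch: compute features, stamp time in ms
    return {"batch": batch, "time": cnt_samples * 1000 / sfreq}

batches = range(5)  # stand-in for the raw-data generator
rows = Parallel(n_jobs=-2)(
    delayed(process)(b, c)
    for b, c in zip(batches, count(offset_start, sample_add))
)
print(rows[0])  # {'batch': 0, 'time': 1000.0}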

@@ -275,6 +312,8 @@ def run(
data: np.ndarray | pd.DataFrame = None,
out_path_root: _PathLike | None = None,
folder_name: str = "sub",
parallel: bool = False,
n_jobs: int = -2
) -> pd.DataFrame:
"""Call run function for offline stream.
@@ -302,5 +341,8 @@
data = self._handle_data(self.data)
elif self.data is None and data is None:
raise ValueError("No data passed to run function.")

if parallel is True:
self._check_settings_for_parallel()

return self._run_offline(data, out_path_root, folder_name)
return self._run_offline(data, out_path_root, folder_name, parallel=parallel, n_jobs=n_jobs)
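
Taken together, a hedged usage sketch of the new run() signature (assumes an already-configured offline stream object; everything except the run arguments is illustrative):

import numpy as np

# stream = Stream(...)               # constructor elided; configured elsewhere
data = np.random.randn(10, 10_000)   # 10 channels, 10 s at 1 kHz (made up)

features = stream.run(
    data=data,
    parallel=True,  # runs _check_settings_for_parallel() before dispatching
    n_jobs=-2,      # joblib convention: all cores but one
)
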
1 change: 1 addition & 0 deletions pyproject.toml
@@ -53,6 +53,7 @@ dependencies = [
"pybispectra>=1.1.0",
"pyparrm",
"pyarrow>=14.0.2",
"joblib>=1.3.2",
]

[project.optional-dependencies]