Skip to content

Commit

Permalink
fix target estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Jan 30, 2024
1 parent 8df33fc commit f82b40b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/plot_1_example_BIDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
settings["features"]["fft"] = True
settings["features"]["bursts"] = True
settings["features"]["sharpwave_analysis"] = True
settings["features"]["coherence"] = True # True
settings["features"]["coherence"] = True
settings["coherence"]["channels"] = [["LFP_RIGHT_0", "ECOG_RIGHT_0"]]
settings["coherence"]["frequency_bands"] = ["high beta", "low gamma"]
settings["sharpwave_analysis_settings"]["estimator"]["mean"] = []
Expand Down
21 changes: 14 additions & 7 deletions py_neuromodulation/nm_stream_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _handle_data(self, data: np.ndarray | pd.DataFrame) -> np.ndarray:
f"Data columns: {names_data}, nm_channels.name: {names_data}."
)
return data.to_numpy()

def _check_settings_for_parallel(self):
"""Check specified settings and raise error if parallel processing is not possible.
Expand All @@ -122,12 +122,14 @@ def _check_settings_for_parallel(self):
"Parallel processing is not possible with burst estimation."
)


def _process_batch(self, data_batch, cnt_samples):
feature_series = self.run_analysis.process(
data_batch.astype(np.float64)
)
feature_series = self._add_timestamp(feature_series, cnt_samples)
feature_series = self._add_target(
feature_series=feature_series, data=data_batch
)
return feature_series

def _run_offline(
Expand Down Expand Up @@ -171,13 +173,16 @@ def _run_offline(
feature_series = self._add_timestamp(
feature_series, cnt_samples
)

feature_series = self._add_target(
feature_series=feature_series, data=data_batch
)

l_features.append(feature_series)

cnt_samples += sample_add
feature_df = pd.DataFrame(l_features)

feature_df = self._add_target(feature_series=feature_df, data=data)

self.save_after_stream(out_path_root, folder_name, feature_df)

return feature_df
Expand Down Expand Up @@ -313,7 +318,7 @@ def run(
out_path_root: _PathLike | None = None,
folder_name: str = "sub",
parallel: bool = False,
n_jobs: int = -2
n_jobs: int = -2,
) -> pd.DataFrame:
"""Call run function for offline stream.
Expand Down Expand Up @@ -341,8 +346,10 @@ def run(
data = self._handle_data(self.data)
elif self.data is None and data is None:
raise ValueError("No data passed to run function.")

if parallel is True:
self._check_settings_for_parallel()

return self._run_offline(data, out_path_root, folder_name, parallel=parallel, n_jobs=n_jobs)
return self._run_offline(
data, out_path_root, folder_name, parallel=parallel, n_jobs=n_jobs
)

0 comments on commit f82b40b

Please sign in to comment.