From ce683a64d052752b8071e31f52304199680fbdf8 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Thu, 23 Nov 2023 11:09:56 +0100 Subject: [PATCH] fix for new bispectrum version --- py_neuromodulation/nm_bispectra.py | 91 +++++++++++++++++++++--------- tests/test_bispectra.py | 9 ++- 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/py_neuromodulation/nm_bispectra.py b/py_neuromodulation/nm_bispectra.py index 522be503..70386b64 100644 --- a/py_neuromodulation/nm_bispectra.py +++ b/py_neuromodulation/nm_bispectra.py @@ -4,9 +4,11 @@ from py_neuromodulation import nm_features_abc -class Bispectra(nm_features_abc.Feature): - def __init__(self, settings: dict, ch_names: Iterable[str], sfreq: int | float) -> None: +class Bispectra(nm_features_abc.Feature): + def __init__( + self, settings: dict, ch_names: Iterable[str], sfreq: int | float + ) -> None: super().__init__(settings, ch_names, sfreq) self.sfreq = sfreq self.ch_names = ch_names @@ -40,18 +42,14 @@ def test_range(f_name, filter_range): ) test_range("f1s", s["bispectrum"]["f1s"]) - test_range("f2s", s["bispectrum"]["f2s"]) + test_range("f2s", s["bispectrum"]["f2s"]) - for feature_name, val in s["bispectrum"][ - "components" - ].items(): + for feature_name, val in s["bispectrum"]["components"].items(): assert isinstance( val, bool ), f"bispectrum component {feature_name} has to be of type bool, got {val}" - - for feature_name, val in s["bispectrum"][ - "bispectrum_features" - ].items(): + + for feature_name, val in s["bispectrum"]["bispectrum_features"].items(): assert isinstance( val, bool ), f"bispectrum feature {feature_name} has to be of type bool, got {val}" @@ -66,10 +64,15 @@ def test_range(f_name, filter_range): f"specified frequency_ranges_hz: {s['frequency_ranges_hz']}" ) - def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_name: str, component: str, f_band: str) -> dict: - + def compute_bs_features( + self, + spectrum_ch: np.array, + features_compute: dict, + ch_name: str, + component: str, + f_band: str, + ) -> dict: for bispectrum_feature in self.s["bispectrum"]["bispectrum_features"]: - if bispectrum_feature == "mean": func = np.nanmean if bispectrum_feature == "sum": @@ -78,9 +81,25 @@ def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_ func = np.nanvar if f_band is not None: - str_feature = "_".join([ch_name, "Bispectrum", component, bispectrum_feature, f_band]) + str_feature = "_".join( + [ + ch_name, + "Bispectrum", + component, + bispectrum_feature, + f_band, + ] + ) else: - str_feature = "_".join([ch_name, "Bispectrum", component, bispectrum_feature, "whole_fband_range"]) + str_feature = "_".join( + [ + ch_name, + "Bispectrum", + component, + bispectrum_feature, + "whole_fband_range", + ] + ) features_compute[str_feature] = func(spectrum_ch) @@ -89,13 +108,18 @@ def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_ def calc_feature(self, data: np.array, features_compute: dict) -> dict: for ch_idx, ch_name in enumerate(self.ch_names): fft_coeffs, freqs = compute_fft( - data=np.expand_dims(data[ch_idx, :], axis=(0,1)), + data=np.expand_dims(data[ch_idx, :], axis=(0, 1)), sampling_freq=self.sfreq, n_points=data.shape[1], - verbose=False, + verbose=False, ) - f_spectrum_range = freqs[np.logical_and(freqs >= np.min([self.f1s, self.f2s]), freqs <= np.max([self.f1s, self.f2s]))] + f_spectrum_range = freqs[ + np.logical_and( + freqs >= np.min([self.f1s, self.f2s]), + freqs <= np.max([self.f1s, self.f2s]), + ) + ] waveshape = WaveShape( data=fft_coeffs, @@ -104,7 +128,9 @@ def calc_feature(self, data: np.array, features_compute: dict) -> dict: verbose=False, ) - waveshape.compute(f1s=tuple(self.f1s[0], self.f1s[-1]), f2s=tuple(self.f2s[0], self.f2s[-1])) + waveshape.compute( + f1s=(self.f1s[0], self.f1s[-1]), f2s=(self.f2s[0], self.f2s[-1]) + ) bispectrum = np.squeeze(waveshape.results._data) @@ -120,14 +146,23 @@ def calc_feature(self, data: np.array, features_compute: dict) -> dict: spectrum_ch = np.angle(bispectrum) for fb in self.s["bispectrum"]["frequency_bands"]: - range_ = (f_spectrum_range >= self.s["frequency_ranges_hz"][fb][0]) \ - & (f_spectrum_range <= self.s["frequency_ranges_hz"][fb][1]) - #waveshape.results.plot() - data_bs = spectrum_ch[range_, range_] - - features_compute = self.compute_bs_features(data_bs, features_compute, ch_name, component, fb) - - if self.s["bispectrum"]["compute_features_for_whole_fband_range"]: - features_compute = self.compute_bs_features(spectrum_ch, features_compute, ch_name, component, None) + range_ = ( + f_spectrum_range >= self.s["frequency_ranges_hz"][fb][0] + ) & ( + f_spectrum_range <= self.s["frequency_ranges_hz"][fb][1] + ) + # waveshape.results.plot() + data_bs = spectrum_ch[range_, range_] + + features_compute = self.compute_bs_features( + data_bs, features_compute, ch_name, component, fb + ) + + if self.s["bispectrum"][ + "compute_features_for_whole_fband_range" + ]: + features_compute = self.compute_bs_features( + spectrum_ch, features_compute, ch_name, component, None + ) return features_compute diff --git a/tests/test_bispectra.py b/tests/test_bispectra.py index 19ea73be..b54d74ad 100644 --- a/tests/test_bispectra.py +++ b/tests/test_bispectra.py @@ -9,11 +9,11 @@ nm_IO, nm_plots, nm_settings, - nm_stream_offline + nm_stream_offline, ) -def test_bispectrum(): +def test_bispectrum(): ( RUN_NAME, PATH_RUN, @@ -64,4 +64,7 @@ def test_bispectrum(): features = stream.run(np.expand_dims(data[3, :], axis=0)) - assert features["ECOG_RIGHT_1_Bispectrum_phase_mean_whole_fband_range"].sum() != 0 \ No newline at end of file + assert ( + features["ECOG_RIGHT_1_Bispectrum_phase_mean_whole_fband_range"].sum() + != 0 + )