Skip to content

Commit

Permalink
fix for new bispectrum version
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Nov 23, 2023
1 parent b03816c commit ce683a6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 31 deletions.
91 changes: 63 additions & 28 deletions py_neuromodulation/nm_bispectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from py_neuromodulation import nm_features_abc

class Bispectra(nm_features_abc.Feature):

def __init__(self, settings: dict, ch_names: Iterable[str], sfreq: int | float) -> None:
class Bispectra(nm_features_abc.Feature):
def __init__(
self, settings: dict, ch_names: Iterable[str], sfreq: int | float
) -> None:
super().__init__(settings, ch_names, sfreq)
self.sfreq = sfreq
self.ch_names = ch_names
Expand Down Expand Up @@ -40,18 +42,14 @@ def test_range(f_name, filter_range):
)

test_range("f1s", s["bispectrum"]["f1s"])
test_range("f2s", s["bispectrum"]["f2s"])
test_range("f2s", s["bispectrum"]["f2s"])

for feature_name, val in s["bispectrum"][
"components"
].items():
for feature_name, val in s["bispectrum"]["components"].items():
assert isinstance(
val, bool
), f"bispectrum component {feature_name} has to be of type bool, got {val}"

for feature_name, val in s["bispectrum"][
"bispectrum_features"
].items():

for feature_name, val in s["bispectrum"]["bispectrum_features"].items():
assert isinstance(
val, bool
), f"bispectrum feature {feature_name} has to be of type bool, got {val}"
Expand All @@ -66,10 +64,15 @@ def test_range(f_name, filter_range):
f"specified frequency_ranges_hz: {s['frequency_ranges_hz']}"
)

def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_name: str, component: str, f_band: str) -> dict:

def compute_bs_features(
self,
spectrum_ch: np.array,
features_compute: dict,
ch_name: str,
component: str,
f_band: str,
) -> dict:
for bispectrum_feature in self.s["bispectrum"]["bispectrum_features"]:

if bispectrum_feature == "mean":
func = np.nanmean
if bispectrum_feature == "sum":
Expand All @@ -78,9 +81,25 @@ def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_
func = np.nanvar

if f_band is not None:
str_feature = "_".join([ch_name, "Bispectrum", component, bispectrum_feature, f_band])
str_feature = "_".join(
[
ch_name,
"Bispectrum",
component,
bispectrum_feature,
f_band,
]
)
else:
str_feature = "_".join([ch_name, "Bispectrum", component, bispectrum_feature, "whole_fband_range"])
str_feature = "_".join(
[
ch_name,
"Bispectrum",
component,
bispectrum_feature,
"whole_fband_range",
]
)

features_compute[str_feature] = func(spectrum_ch)

Expand All @@ -89,13 +108,18 @@ def compute_bs_features(self, spectrum_ch: np.array, features_compute: dict, ch_
def calc_feature(self, data: np.array, features_compute: dict) -> dict:
for ch_idx, ch_name in enumerate(self.ch_names):
fft_coeffs, freqs = compute_fft(
data=np.expand_dims(data[ch_idx, :], axis=(0,1)),
data=np.expand_dims(data[ch_idx, :], axis=(0, 1)),
sampling_freq=self.sfreq,
n_points=data.shape[1],
verbose=False,
verbose=False,
)

f_spectrum_range = freqs[np.logical_and(freqs >= np.min([self.f1s, self.f2s]), freqs <= np.max([self.f1s, self.f2s]))]
f_spectrum_range = freqs[
np.logical_and(
freqs >= np.min([self.f1s, self.f2s]),
freqs <= np.max([self.f1s, self.f2s]),
)
]

waveshape = WaveShape(
data=fft_coeffs,
Expand All @@ -104,7 +128,9 @@ def calc_feature(self, data: np.array, features_compute: dict) -> dict:
verbose=False,
)

waveshape.compute(f1s=tuple(self.f1s[0], self.f1s[-1]), f2s=tuple(self.f2s[0], self.f2s[-1]))
waveshape.compute(
f1s=(self.f1s[0], self.f1s[-1]), f2s=(self.f2s[0], self.f2s[-1])
)

bispectrum = np.squeeze(waveshape.results._data)

Expand All @@ -120,14 +146,23 @@ def calc_feature(self, data: np.array, features_compute: dict) -> dict:
spectrum_ch = np.angle(bispectrum)

for fb in self.s["bispectrum"]["frequency_bands"]:
range_ = (f_spectrum_range >= self.s["frequency_ranges_hz"][fb][0]) \
& (f_spectrum_range <= self.s["frequency_ranges_hz"][fb][1])
#waveshape.results.plot()
data_bs = spectrum_ch[range_, range_]

features_compute = self.compute_bs_features(data_bs, features_compute, ch_name, component, fb)

if self.s["bispectrum"]["compute_features_for_whole_fband_range"]:
features_compute = self.compute_bs_features(spectrum_ch, features_compute, ch_name, component, None)
range_ = (
f_spectrum_range >= self.s["frequency_ranges_hz"][fb][0]
) & (
f_spectrum_range <= self.s["frequency_ranges_hz"][fb][1]
)
# waveshape.results.plot()
data_bs = spectrum_ch[range_, range_]

features_compute = self.compute_bs_features(
data_bs, features_compute, ch_name, component, fb
)

if self.s["bispectrum"][
"compute_features_for_whole_fband_range"
]:
features_compute = self.compute_bs_features(
spectrum_ch, features_compute, ch_name, component, None
)

return features_compute
9 changes: 6 additions & 3 deletions tests/test_bispectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
nm_IO,
nm_plots,
nm_settings,
nm_stream_offline
nm_stream_offline,
)

def test_bispectrum():

def test_bispectrum():
(
RUN_NAME,
PATH_RUN,
Expand Down Expand Up @@ -64,4 +64,7 @@ def test_bispectrum():

features = stream.run(np.expand_dims(data[3, :], axis=0))

assert features["ECOG_RIGHT_1_Bispectrum_phase_mean_whole_fband_range"].sum() != 0
assert (
features["ECOG_RIGHT_1_Bispectrum_phase_mean_whole_fband_range"].sum()
!= 0
)

0 comments on commit ce683a6

Please sign in to comment.