diff --git a/.gitignore b/.gitignore index bd08f9ba..6397ccf7 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,6 @@ examples/utils/* sub/ -py_neuromodulation/data/derivatives/* \ No newline at end of file +py_neuromodulation/data/derivatives/* + +plot_0_first_demo.py diff --git a/py_neuromodulation/nm_filter.py b/py_neuromodulation/nm_filter.py index b4ecebb3..4070e9c1 100644 --- a/py_neuromodulation/nm_filter.py +++ b/py_neuromodulation/nm_filter.py @@ -1,4 +1,5 @@ """Module for filter functionality.""" + import mne from mne.filter import _overlap_add_filter import numpy as np @@ -53,17 +54,27 @@ def __init__( f_ranges = [f_ranges] for f_range in f_ranges: - filt = mne.filter.create_filter( - None, - sfreq, - l_freq=f_range[0], - h_freq=f_range[1], - fir_design="firwin", - l_trans_bandwidth=l_trans_bandwidth, # type: ignore - h_trans_bandwidth=h_trans_bandwidth, # type: ignore - filter_length=filter_length, # type: ignore - verbose=verbose, - ) + try: + filt = mne.filter.create_filter( + None, + sfreq, + l_freq=f_range[0], + h_freq=f_range[1], + fir_design="firwin", + l_trans_bandwidth=l_trans_bandwidth, # type: ignore + h_trans_bandwidth=h_trans_bandwidth, # type: ignore + filter_length=filter_length, # type: ignore + verbose=verbose, + ) + except: + filt = mne.filter.create_filter( + None, + sfreq, + l_freq=f_range[0], + h_freq=f_range[1], + fir_design="firwin", + verbose=verbose, + ) filter_bank.append(filt) self.filter_bank = np.vstack(filter_bank) @@ -82,10 +93,6 @@ def filter_data(self, data: np.ndarray) -> np.ndarray: np.ndarray, shape (n_channels, n_fbands, n_samples) Filtered data. - Raises - ------ - ValueError - If data.ndim > 2 """ if data.ndim > 2: raise ValueError( @@ -94,15 +101,14 @@ def filter_data(self, data: np.ndarray) -> np.ndarray: ) if data.ndim == 1: data = np.expand_dims(data, axis=0) + filtered = np.array( [ - [ - np.convolve(flt, chan, mode="same") - for flt in self.filter_bank - ] + [np.convolve(self.filter_bank[0, :], chan, mode="same")] for chan in data ] ) + return filtered @@ -175,19 +181,19 @@ def __init__( phase="zero", fir_window="hamming", fir_design="firwin", - verbose=False + verbose=False, ) def process(self, data: np.ndarray) -> np.ndarray: if self.filter_bank is None: return data return _overlap_add_filter( - x=data, - h=self.filter_bank, - n_fft=None, - phase="zero", - picks=None, - n_jobs=1, - copy=True, - pad="reflect_limited", - ) + x=data, + h=self.filter_bank, + n_fft=None, + phase="zero", + picks=None, + n_jobs=1, + copy=True, + pad="reflect_limited", + ) diff --git a/py_neuromodulation/nm_filter_preprocessing.py b/py_neuromodulation/nm_filter_preprocessing.py index 32010b28..3da9f784 100644 --- a/py_neuromodulation/nm_filter_preprocessing.py +++ b/py_neuromodulation/nm_filter_preprocessing.py @@ -5,9 +5,7 @@ class PreprocessingFilter: - def __init__( - self, settings: dict, sfreq: int | float - ) -> None: + def __init__(self, settings: dict, sfreq: int | float) -> None: self.s = settings self.sfreq = sfreq self.filters = [] @@ -18,10 +16,10 @@ def __init__( f_ranges=[ self.s["preprocessing_filter"][ "bandstop_filter_settings" - ]["frequency_low_hz"], + ]["frequency_high_hz"], self.s["preprocessing_filter"][ "bandstop_filter_settings" - ]["frequency_high_hz"], + ]["frequency_low_hz"], ], sfreq=self.sfreq, filter_length=self.sfreq - 1, @@ -87,6 +85,7 @@ def process(self, data: np.ndarray) -> np.ndarray: """ for filter in self.filters: - data = filter.filter_data(data) - - return data + data = filter.filter_data( + data if len(data.shape) == 2 else data[:, 0, :] + ) + return data if len(data.shape) == 2 else data[:, 0, :] diff --git a/py_neuromodulation/nm_settings.json b/py_neuromodulation/nm_settings.json index a16feb86..8e722220 100644 --- a/py_neuromodulation/nm_settings.json +++ b/py_neuromodulation/nm_settings.json @@ -4,8 +4,7 @@ "preprocessing": [ "raw_resampling", "notch_filter", - "re_referencing", - "preprocessing_filter" + "re_referencing" ], "documentation_preprocessing_options": [ "raw_resampling", @@ -38,23 +37,23 @@ "raw_resampling_settings": { "resample_freq_hz": 1000 }, - "preprocessing_filter" : { - "bandstop_filter" : true, - "lowpass_filter" : true, - "highpass_filter" : true, - "bandpass_filter" : true, - "bandstop_filter_settings" : { + "preprocessing_filter": { + "bandstop_filter": true, + "lowpass_filter": true, + "highpass_filter": true, + "bandpass_filter": true, + "bandstop_filter_settings": { "frequency_low_hz": 100, "frequency_high_hz": 160 }, - "lowpass_filter_settings" : { + "lowpass_filter_settings": { "frequency_cutoff_hz": 200 }, - "highpass_filter_settings" : { - "frequency_cutoff_hz": 1 + "highpass_filter_settings": { + "frequency_cutoff_hz": 3 }, - "bandpass_filter_settings" : { - "frequency_low_hz": 1, + "bandpass_filter_settings": { + "frequency_low_hz": 3, "frequency_high_hz": 200 } }, diff --git a/tests/conftest.py b/tests/conftest.py index 0022c9b6..93cae015 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,80 @@ import pytest import numpy as np -from py_neuromodulation.nm_rereference import ReReferencer from py_neuromodulation import ( nm_generator, + nm_stream_offline, nm_settings, nm_IO, nm_define_nmchannels, ) + +@pytest.fixture +def setup_default_stream_fast_compute(): + """This test function sets a data batch and automatic initialized M1 datafram + + Args: + PATH_PYNEUROMODULATION (string): Path to py_neuromodulation repository + + Returns: + ieeg_batch (np.ndarray): (channels, samples) + df_M1 (pd Dataframe): auto intialized table for rereferencing + settings_wrapper (settings.py): settings.json + fs (float): example sampling frequency + """ + + ( + RUN_NAME, + PATH_RUN, + PATH_BIDS, + PATH_OUT, + datatype, + ) = nm_IO.get_paths_example_data() + + ( + raw, + data, + sfreq, + line_noise, + coord_list, + coord_names, + ) = nm_IO.read_BIDS_data( + PATH_RUN=PATH_RUN, BIDS_PATH=PATH_BIDS, datatype=datatype + ) + + nm_channels = nm_define_nmchannels.set_channels( + ch_names=raw.ch_names, + ch_types=raw.get_channel_types(), + reference="default", + bads=raw.info["bads"], + new_names="default", + used_types=("ecog", "dbs", "seeg"), + target_keywords=("MOV_RIGHT_CLEAN",), + ) + + settings = nm_settings.get_default_settings() + settings = nm_settings.reset_settings(settings) + settings["fooof"]["aperiodic"]["exponent"] = True + settings["fooof"]["aperiodic"]["offset"] = True + settings["features"]["fooof"] = True + + stream = nm_stream_offline.Stream( + settings=settings, + nm_channels=nm_channels, + path_grids=None, + verbose=True, + sfreq=sfreq, + line_noise=line_noise, + coord_list=coord_list, + coord_names=coord_names, + ) + + return data, stream + + @pytest.fixture -def setup(): +def setup_databatch(): """This test function sets a data batch and automatic initialized M1 datafram Args: diff --git a/tests/test_fooof.py b/tests/test_fooof.py index 5dc2a559..1ed10636 100644 --- a/tests/test_fooof.py +++ b/tests/test_fooof.py @@ -1,79 +1,9 @@ -import pytest +from py_neuromodulation import nm_generator -from py_neuromodulation import ( - nm_generator, - nm_stream_offline, - nm_IO, - nm_define_nmchannels, - nm_settings, -) +def test_fooof_features(setup_default_stream_fast_compute): -@pytest.fixture -def setup_stream(): - """This test function sets a data batch and automatic initialized M1 datafram - - Args: - PATH_PYNEUROMODULATION (string): Path to py_neuromodulation repository - - Returns: - ieeg_batch (np.ndarray): (channels, samples) - df_M1 (pd Dataframe): auto intialized table for rereferencing - settings_wrapper (settings.py): settings.json - fs (float): example sampling frequency - """ - - ( - RUN_NAME, - PATH_RUN, - PATH_BIDS, - PATH_OUT, - datatype, - ) = nm_IO.get_paths_example_data() - - ( - raw, - data, - sfreq, - line_noise, - coord_list, - coord_names, - ) = nm_IO.read_BIDS_data( - PATH_RUN=PATH_RUN, BIDS_PATH=PATH_BIDS, datatype=datatype - ) - - nm_channels = nm_define_nmchannels.set_channels( - ch_names=raw.ch_names, - ch_types=raw.get_channel_types(), - reference="default", - bads=raw.info["bads"], - new_names="default", - used_types=("ecog", "dbs", "seeg"), - target_keywords=("MOV_RIGHT_CLEAN",), - ) - - settings = nm_settings.get_default_settings() - settings = nm_settings.reset_settings(settings) - settings["fooof"]["aperiodic"]["exponent"] = True - settings["fooof"]["aperiodic"]["offset"] = True - settings["features"]["fooof"] = True - - stream = nm_stream_offline.Stream( - settings=settings, - nm_channels=nm_channels, - path_grids=None, - verbose=True, - sfreq=sfreq, - line_noise=line_noise, - coord_list=coord_list, - coord_names=coord_names, - ) - - return data, stream - -def test_fooof_features(setup_stream): - - data, stream = setup_stream + data, stream = setup_default_stream_fast_compute generator = nm_generator.raw_data_generator( data, stream.settings, stream.sfreq diff --git a/tests/test_preprocessing_filter.py b/tests/test_preprocessing_filter.py index ee88f08a..139e680b 100644 --- a/tests/test_preprocessing_filter.py +++ b/tests/test_preprocessing_filter.py @@ -1,16 +1,151 @@ +import numpy as np +from scipy import signal +from py_neuromodulation import nm_settings from py_neuromodulation.nm_filter_preprocessing import PreprocessingFilter -def test_preprocessing_filter(setup): - ch_names, ch_types, bads, data_batch = setup + +def test_preprocessing_within_pipeline(setup_default_stream_fast_compute): + + data, stream = setup_default_stream_fast_compute + + stream.settings["preprocessing"].append("preprocessing_filter") + + stream.settings["preprocessing_filter"]["bandstop_filter"] = True + stream.settings["preprocessing_filter"]["bandpass_filter"] = True + stream.settings["preprocessing_filter"]["lowpass_filter"] = True + stream.settings["preprocessing_filter"]["highpass_filter"] = True + + stream.sfreq + + try: + _ = stream.run(data[:, : int(stream.sfreq * 2)]) + except Exception as e: + assert False, f"Error in pipeline including preprocess filtering : {e}" + + +def test_preprocessing_filter_lowpass(): + + data_batch = np.random.random([1, 1000]) + + settings = nm_settings.get_default_settings() + settings["preprocessing"] = settings["preprocessing"].append( + "preprocessing_filter" + ) + settings["preprocessing_filter"]["lowpass_filter"] = True + settings["preprocessing_filter"]["highpass_filter"] = False + settings["preprocessing_filter"]["bandpass_filter"] = False + settings["preprocessing_filter"]["bandstop_filter"] = False + + settings["preprocessing_filter"]["lowpass_filter_settings"][ + "frequency_cutoff_hz" + ] = 100 + + sfreq = 1000 + + preprocessing_filter = PreprocessingFilter(settings, sfreq) + data_filtered = preprocessing_filter.process(data_batch) + + # compute a scipy signal welch to check if the filter worked + f, Pxx = signal.welch(data_batch, fs=sfreq, nperseg=1000) + f, Pxx_f = signal.welch(data_filtered, fs=sfreq, nperseg=1000) + + # check if the power in the frequency range of the lowpass filter is reduced + assert np.mean(Pxx_f[0, 100:500]) < np.mean(Pxx[0, 100:500]) + + +def test_preprocessing_filter_highpass(): + + data_batch = np.random.random([1, 1000]) + + settings = nm_settings.get_default_settings() + settings["preprocessing"] = settings["preprocessing"].append( + "preprocessing_filter" + ) + settings["preprocessing_filter"]["highpass_filter"] = True + settings["preprocessing_filter"]["lowpass_filter"] = False + settings["preprocessing_filter"]["bandpass_filter"] = False + settings["preprocessing_filter"]["bandstop_filter"] = False + + settings["preprocessing_filter"]["highpass_filter_settings"][ + "frequency_cutoff_hz" + ] = 100 + + sfreq = 1000 + + preprocessing_filter = PreprocessingFilter(settings, sfreq) + data_filtered = preprocessing_filter.process(data_batch) + + # compute a scipy signal welch to check if the filter worked + f, Pxx = signal.welch(data_batch, fs=sfreq, nperseg=1000) + f, Pxx_f = signal.welch(data_filtered, fs=sfreq, nperseg=1000) + + # check if the power in the frequency range of the highpass filter is reduced + assert np.mean(Pxx_f[0, 0:100]) < np.mean(Pxx[0, 0:100]) + + +def test_preprocessing_filter_bandstop(): + + data_batch = np.random.random([1, 1000]) settings = nm_settings.get_default_settings() - settings = nm_settings.set_settings_fast_compute(settings) + settings["preprocessing"] = settings["preprocessing"].append( + "preprocessing_filter" + ) + settings["preprocessing_filter"]["bandstop_filter"] = True + settings["preprocessing_filter"]["bandpass_filter"] = False + settings["preprocessing_filter"]["lowpass_filter"] = False + settings["preprocessing_filter"]["highpass_filter"] = False + + settings["preprocessing_filter"]["bandstop_filter_settings"][ + "frequency_low_hz" + ] = 100 + settings["preprocessing_filter"]["bandstop_filter_settings"][ + "frequency_high_hz" + ] = 160 + + sfreq = 1000 preprocessing_filter = PreprocessingFilter(settings, sfreq) + data_filtered = preprocessing_filter.process(data_batch) + + # compute a scipy signal welch to check if the filter worked + f, Pxx = signal.welch(data_batch, fs=sfreq, nperseg=1000) + f, Pxx_f = signal.welch(data_filtered, fs=sfreq, nperseg=1000) + + # check if the power in the frequency range of the bandstop filter is reduced + assert np.mean(Pxx_f[0, 100:160]) < np.mean(Pxx[0, 100:160]) + + +def test_preprocessing_filter_bandpass(): + + data_batch = np.random.random([1, 1000]) + settings = nm_settings.get_default_settings() + settings["preprocessing"] = settings["preprocessing"].append( + "preprocessing_filter" + ) + settings["preprocessing_filter"]["bandstop_filter"] = False + settings["preprocessing_filter"]["bandpass_filter"] = True + settings["preprocessing_filter"]["lowpass_filter"] = False + settings["preprocessing_filter"]["highpass_filter"] = False + + settings["preprocessing_filter"]["bandpass_filter_settings"][ + "frequency_low_hz" + ] = 100 + settings["preprocessing_filter"]["bandpass_filter_settings"][ + "frequency_high_hz" + ] = 160 + + sfreq = 1000 + + preprocessing_filter = PreprocessingFilter(settings, sfreq) data_filtered = preprocessing_filter.process(data_batch) - assert data_filtered.shape == data_batch.shape + # compute a scipy signal welch to check if the filter worked + f, Pxx = signal.welch(data_batch, fs=sfreq, nperseg=1000) + f, Pxx_f = signal.welch(data_filtered, fs=sfreq, nperseg=1000) - assert data_filtered != data_batch \ No newline at end of file + # check if the power in the frequency range of the bandpass filter is reduced + assert np.mean(Pxx_f[0, 0:100]) < np.mean(Pxx[0, 0:100]) + assert np.mean(Pxx_f[0, 160:500]) < np.mean(Pxx[0, 160:500]) diff --git a/tests/test_rereference.py b/tests/test_rereference.py index 54c5fbad..341e1f80 100644 --- a/tests/test_rereference.py +++ b/tests/test_rereference.py @@ -14,8 +14,8 @@ ) -def test_rereference_not_used_channels_no_reref(setup): - ch_names, ch_types, bads, data_batch = setup +def test_rereference_not_used_channels_no_reref(setup_databatch): + ch_names, ch_types, bads, data_batch = setup_databatch nm_channels = nm_define_nmchannels.set_channels( ch_names=ch_names,