def _contains_nan(data) -> bool:
    """Return True when the sum of all elements of ``data`` is NaN.

    NaN propagates through addition, so a single vectorised ``np.sum`` is
    enough to tell whether the NaN-aware reducers are needed; no boolean
    mask of the whole input has to be allocated.  ``np.sum`` is used rather
    than the builtin ``sum`` so the reduction runs in native code and also
    accepts 0-d arrays and nested sequences.

    Note: a sum that is NaN for another reason (e.g. ``inf + -inf``) also
    returns True.  That only selects the slower NaN-aware reducer; the
    choice of code path is affected, not correctness.
    """
    return bool(np.isnan(np.sum(data)))


def nan_mean(data, axis):
    """Mean of ``data`` along ``axis``, ignoring NaNs.

    Calls ``np.nanmean`` only when the NaN probe fires and plain ``np.mean``
    otherwise.

    Args:
        data: Array-like of numeric values.
        axis: Axis (or ``None``) forwarded to the NumPy reducer.

    Returns:
        The result of ``np.nanmean`` or ``np.mean`` for ``data``.
    """
    if _contains_nan(data):
        return np.nanmean(data, axis=axis)
    return np.mean(data, axis=axis)


def nan_std(data, axis):
    """Standard deviation of ``data`` along ``axis``, ignoring NaNs.

    Calls ``np.nanstd`` only when the NaN probe fires and plain ``np.std``
    otherwise.

    Args:
        data: Array-like of numeric values.
        axis: Axis (or ``None``) forwarded to the NumPy reducer.

    Returns:
        The result of ``np.nanstd`` or ``np.std`` for ``data``.
    """
    if _contains_nan(data):
        return np.nanstd(data, axis=axis)
    return np.std(data, axis=axis)


def nan_median(data, axis):
    """Median of ``data`` along ``axis``, ignoring NaNs.

    Calls ``np.nanmedian`` only when the NaN probe fires and plain
    ``np.median`` otherwise.

    Args:
        data: Array-like of numeric values.
        axis: Axis (or ``None``) forwarded to the NumPy reducer.

    Returns:
        The result of ``np.nanmedian`` or ``np.median`` for ``data``.
    """
    if _contains_nan(data):
        return np.nanmedian(data, axis=axis)
    return np.median(data, axis=axis)
- elif method == NORM_METHODS.QUANTILE.value: - if len(current.shape) == 1: - current = ( - preprocessing.QuantileTransformer(n_quantiles=300) - .fit(np.nan_to_num(previous)) - .transform(np.expand_dims(current, axis=0))[0, :] - ) - else: - current = ( - preprocessing.QuantileTransformer(n_quantiles=300) - .fit(np.nan_to_num(previous)) - .transform(current) - ) - elif method == NORM_METHODS.ROBUST.value: - if len(current.shape) == 1: - current = ( - preprocessing.RobustScaler() - .fit(np.nan_to_num(previous)) - .transform(np.expand_dims(current, axis=0))[0, :] - ) - else: + match method: + case NORM_METHODS.MEAN.value: + mean = nan_mean(previous, axis=0) + current = (current - mean) / mean + case NORM_METHODS.MEDIAN.value: + median = nan_median(previous, axis=0) + current = (current - median) / median + case NORM_METHODS.ZSCORE.value: + current = (current - nan_mean(previous, axis=0)) / nan_std(previous, axis=0) + case NORM_METHODS.ZSCORE_MEDIAN.value: + current = (current - nan_median(previous, axis=0)) / nan_std(previous, axis=0) + # For the following methods we check for the shape of current + # when current is a 1D array, then it is the post-processing normalization, + # and we need to expand, and remove the extra dimension afterwards + # When current is a 2D array, then it is pre-processing normalization, and + # there's no need for expanding. 
+ case (NORM_METHODS.QUANTILE.value | + NORM_METHODS.ROBUST.value | + NORM_METHODS.MINMAX.value | + NORM_METHODS.POWER.value): + + norm_methods = { + NORM_METHODS.QUANTILE.value : lambda: preprocessing.QuantileTransformer(n_quantiles=300), + NORM_METHODS.ROBUST.value : preprocessing.RobustScaler, + NORM_METHODS.MINMAX.value : preprocessing.MinMaxScaler, + NORM_METHODS.POWER.value : preprocessing.PowerTransformer + } + current = ( - preprocessing.RobustScaler() + norm_methods[method]() .fit(np.nan_to_num(previous)) - .transform(current) + .transform( + # if post-processing: pad dimensions to 2 + np.reshape(current, (2-len(current.shape))*(1,) + current.shape) + ) + .squeeze() # if post-processing: remove extra dimension ) - - elif method == NORM_METHODS.MINMAX.value: - if len(current.shape) == 1: - current = ( - preprocessing.MinMaxScaler() - .fit(np.nan_to_num(previous)) - .transform(np.expand_dims(current, axis=0))[0, :] - ) - else: - current = ( - preprocessing.MinMaxScaler() - .fit(np.nan_to_num(previous)) - .transform(current) - ) - elif method == NORM_METHODS.POWER.value: - if len(current.shape) == 1: - current = ( - preprocessing.PowerTransformer() - .fit(np.nan_to_num(previous)) - .transform(np.expand_dims(current, axis=0))[0, :] + + case _: + raise ValueError( + f"Only {[e.value for e in NORM_METHODS]} are supported as " + f"{description} normalization methods. Got {method}." ) - else: - current = ( - preprocessing.PowerTransformer() - .fit(np.nan_to_num(previous)) - .transform(current) - ) - else: - raise ValueError( - f"Only {[e.value for e in NORM_METHODS]} are supported as " - f"{description} normalization methods. Got {method}." - ) if clip: current = _clip(data=current, clip=clip)