Commit 8df33fc
…ulation into main
timonmerk committed Jan 28, 2024

Verified: this commit was signed with the committer's verified signature (stephancill, Stephan Cilliers).
2 parents 018d18b + 906f50d, commit 8df33fc
Showing 2 changed files with 55 additions and 86 deletions.
py_neuromodulation/nm_bursts.py (17 changes: 5 additions & 12 deletions)
@@ -167,25 +167,18 @@ def get_burst_amplitude_length(
     bursts = np.zeros((beta_averp_norm.shape[0] + 1), dtype=bool)
     bursts[1:] = beta_averp_norm >= burst_thr
     deriv = np.diff(bursts)
-    isburst = False
     burst_length = []
     burst_amplitude = []
-    burst_start = 0
+    burst_time_points = np.where(deriv == True)[0]
 
-    for index, burst_state in enumerate(deriv):
-        if burst_state == True:
-            if isburst == True:
-                burst_length.append(index - burst_start)
-                burst_amplitude.append(beta_averp_norm[burst_start:index])
-
-                isburst = False
-            else:
-                burst_start = index
-                isburst = True
+    for i in range(burst_time_points.size // 2):
+        burst_length.append(burst_time_points[2 * i + 1] - burst_time_points[2 * i])
+        burst_amplitude.append(beta_averp_norm[burst_time_points[2 * i] : burst_time_points[2 * i + 1]])
 
     # the last burst length (in case isburst == True) is omitted,
     # since the true burst length cannot be estimated
 
     burst_length = np.array(burst_length) / sfreq
 
     return burst_amplitude, burst_length
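
A quick sketch of what the new vectorized version computes (a toy example of mine, not part of the commit): prepending False to the thresholded signal makes np.diff flag every on/off transition, so even-indexed entries of burst_time_points are burst onsets and odd-indexed entries are offsets, and the floor division silently drops a final burst that never switches off, which is what the retained comment refers to.

import numpy as np

# Toy normalized-power trace and threshold (hypothetical values).
beta_averp_norm = np.array([0.1, 0.9, 0.8, 0.2, 0.7, 0.6, 0.5, 0.1, 0.9])
burst_thr, sfreq = 0.5, 10.0

bursts = np.zeros(beta_averp_norm.shape[0] + 1, dtype=bool)
bursts[1:] = beta_averp_norm >= burst_thr      # leading False lets a burst start at index 0
deriv = np.diff(bursts)                        # True exactly at on/off transitions
burst_time_points = np.where(deriv)[0]         # even entries: onsets, odd entries: offsets

burst_length, burst_amplitude = [], []
for i in range(burst_time_points.size // 2):   # floor division drops an unfinished trailing burst
    start, stop = burst_time_points[2 * i], burst_time_points[2 * i + 1]
    burst_length.append(stop - start)
    burst_amplitude.append(beta_averp_norm[start:stop])

print(np.array(burst_length) / sfreq)          # burst durations in seconds: [0.2 0.3]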

py_neuromodulation/nm_normalization.py (124 changes: 50 additions & 74 deletions)
@@ -3,8 +3,6 @@
 
 from sklearn import preprocessing
 import numpy as np
-
-
 class NORM_METHODS(Enum):
     MEAN = "mean"
     MEDIAN = "median"
@@ -138,6 +136,17 @@ def process(self, data: np.ndarray) -> np.ndarray:
 
         return data
 
+"""
+Functions to check for NaNs before deciding which NumPy function to call
+"""
+def nan_mean(data, axis):
+    return np.nanmean(data, axis=axis) if np.any(np.isnan(sum(data))) else np.mean(data, axis=axis)
+
+def nan_std(data, axis):
+    return np.nanstd(data, axis=axis) if np.any(np.isnan(sum(data))) else np.std(data, axis=axis)
+
+def nan_median(data, axis):
+    return np.nanmedian(data, axis=axis) if np.any(np.isnan(sum(data))) else np.median(data, axis=axis)
 
 def _normalize_and_clip(
     current: np.ndarray,
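
A sanity check on the nan_* helpers added above (my own toy example, not repository code): the builtin sum collapses the array along its first axis and propagates NaN, so one cheap reduction decides whether the slower NaN-aware reduction is needed at all.

import numpy as np

def nan_mean(data, axis):
    # identical to the helper added in the hunk above
    return np.nanmean(data, axis=axis) if np.any(np.isnan(sum(data))) else np.mean(data, axis=axis)

clean = np.array([[1.0, 2.0], [3.0, 4.0]])
dirty = np.array([[1.0, np.nan], [3.0, 4.0]])

print(nan_mean(clean, axis=0))  # [2. 3.]  -> fast np.mean path
print(nan_mean(dirty, axis=0))  # [2. 4.]  -> np.nanmean path skips the NaN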
@@ -147,82 +156,49 @@ def _normalize_and_clip(
     description: str,
 ) -> tuple[np.ndarray, np.ndarray]:
     """Normalize data."""
-    if method == NORM_METHODS.MEAN.value:
-        mean = np.nanmean(previous, axis=0)
-        current = (current - mean) / mean
-    elif method == NORM_METHODS.MEDIAN.value:
-        median = np.nanmedian(previous, axis=0)
-        current = (current - median) / median
-    elif method == NORM_METHODS.ZSCORE.value:
-        mean = np.nanmean(previous, axis=0)
-        current = (current - mean) / np.nanstd(previous, axis=0)
-    elif method == NORM_METHODS.ZSCORE_MEDIAN.value:
-        current = (current - np.nanmedian(previous, axis=0)) / np.nanstd(
-            previous, axis=0
-        )
-    # For the following methods we check for the shape of current
-    # when current is a 1D array, then it is the post-processing normalization,
-    # and we need to expand, and take the [0, :] component
-    # When current is a 2D array, then it is pre-processing normalization, and
-    # there's no need for expanding.
-    elif method == NORM_METHODS.QUANTILE.value:
-        if len(current.shape) == 1:
-            current = (
-                preprocessing.QuantileTransformer(n_quantiles=300)
-                .fit(np.nan_to_num(previous))
-                .transform(np.expand_dims(current, axis=0))[0, :]
-            )
-        else:
-            current = (
-                preprocessing.QuantileTransformer(n_quantiles=300)
-                .fit(np.nan_to_num(previous))
-                .transform(current)
-            )
-    elif method == NORM_METHODS.ROBUST.value:
-        if len(current.shape) == 1:
-            current = (
-                preprocessing.RobustScaler()
-                .fit(np.nan_to_num(previous))
-                .transform(np.expand_dims(current, axis=0))[0, :]
-            )
-        else:
-            current = (
-                preprocessing.RobustScaler()
-                .fit(np.nan_to_num(previous))
-                .transform(current)
-            )
-    elif method == NORM_METHODS.MINMAX.value:
-        if len(current.shape) == 1:
-            current = (
-                preprocessing.MinMaxScaler()
-                .fit(np.nan_to_num(previous))
-                .transform(np.expand_dims(current, axis=0))[0, :]
-            )
-        else:
-            current = (
-                preprocessing.MinMaxScaler()
-                .fit(np.nan_to_num(previous))
-                .transform(current)
-            )
-    elif method == NORM_METHODS.POWER.value:
-        if len(current.shape) == 1:
-            current = (
-                preprocessing.PowerTransformer()
-                .fit(np.nan_to_num(previous))
-                .transform(np.expand_dims(current, axis=0))[0, :]
-            )
-        else:
-            current = (
-                preprocessing.PowerTransformer()
-                .fit(np.nan_to_num(previous))
-                .transform(current)
-            )
-    else:
-        raise ValueError(
-            f"Only {[e.value for e in NORM_METHODS]} are supported as "
-            f"{description} normalization methods. Got {method}."
-        )
+    match method:
+        case NORM_METHODS.MEAN.value:
+            mean = nan_mean(previous, axis=0)
+            current = (current - mean) / mean
+        case NORM_METHODS.MEDIAN.value:
+            median = nan_median(previous, axis=0)
+            current = (current - median) / median
+        case NORM_METHODS.ZSCORE.value:
+            current = (current - nan_mean(previous, axis=0)) / nan_std(previous, axis=0)
+        case NORM_METHODS.ZSCORE_MEDIAN.value:
+            current = (current - nan_median(previous, axis=0)) / nan_std(previous, axis=0)
+        # For the following methods we check for the shape of current
+        # when current is a 1D array, then it is the post-processing normalization,
+        # and we need to expand, and remove the extra dimension afterwards
+        # When current is a 2D array, then it is pre-processing normalization, and
+        # there's no need for expanding.
+        case (NORM_METHODS.QUANTILE.value |
+              NORM_METHODS.ROBUST.value |
+              NORM_METHODS.MINMAX.value |
+              NORM_METHODS.POWER.value):
+
+            norm_methods = {
+                NORM_METHODS.QUANTILE.value: lambda: preprocessing.QuantileTransformer(n_quantiles=300),
+                NORM_METHODS.ROBUST.value: preprocessing.RobustScaler,
+                NORM_METHODS.MINMAX.value: preprocessing.MinMaxScaler,
+                NORM_METHODS.POWER.value: preprocessing.PowerTransformer,
+            }
+
+            current = (
+                norm_methods[method]()
+                .fit(np.nan_to_num(previous))
+                .transform(
+                    # if post-processing: pad dimensions to 2
+                    np.reshape(current, (2-len(current.shape))*(1,) + current.shape)
+                )
+                .squeeze()  # if post-processing: remove extra dimension
+            )
+
+        case _:
+            raise ValueError(
+                f"Only {[e.value for e in NORM_METHODS]} are supported as "
+                f"{description} normalization methods. Got {method}."
+            )
 
     if clip:
         current = _clip(data=current, clip=clip)
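
The dictionary dispatch plus pad/squeeze replaces four nearly identical if/else branches. To see the reshape idiom in isolation (a self-contained sketch with made-up data, not repository code): a 1D feature vector, the post-processing case, is padded to a (1, n_features) row before transform, while a 2D pre-processing batch passes through unchanged, and squeeze() removes the padding again.

import numpy as np
from sklearn import preprocessing

previous = np.random.default_rng(0).normal(size=(100, 3))  # hypothetical calibration window
scaler = preprocessing.MinMaxScaler().fit(np.nan_to_num(previous))

vec = np.array([0.1, -0.2, 0.3])                      # 1D: post-processing case
batch = np.random.default_rng(1).normal(size=(4, 3))  # 2D: pre-processing case

for current in (vec, batch):
    out = scaler.transform(
        # pad a 1D vector to (1, n_features); a 2D array is left unchanged
        np.reshape(current, (2 - len(current.shape)) * (1,) + current.shape)
    ).squeeze()  # drop the padded dimension again in the 1D case
    print(current.shape, "->", out.shape)  # (3,) -> (3,)  then  (4, 3) -> (4, 3)

One caveat worth noting: squeeze() removes every size-one axis, so a genuine single-row 2D batch would also be collapsed to 1D; the surrounding pipeline presumably never passes that shape here.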
