Non-local means denoising of multi-channel electrophysiology timeseries using PyTorch.
The software is in the alpha stage of development. Testers and contributors are welcome to assist.
Electrophysiology recordings contain spike waveforms superimposed on a noisy background signal. Each detected neuron may fire hundreds or even thousands of times over the duration of a recording, producing a very similar voltage trace at each event. Here we aim to leverage this redundancy in order to denoise the recording in a preprocessing step, prior to spike sorting or other analyses. Our approach is to use non-local means to suppress the noisy signal while retaining the part of the signal that repeats.
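As a rough illustration of the idea, here is a minimal NumPy sketch of non-local means on a single-channel signal. It is not the implementation used by this package: the function name `nlm_denoise_1d`, the Gaussian weighting, and the normalization by clip size are all assumptions made for the demo, and the cost is quadratic in the number of clips.

```python
import numpy as np


def nlm_denoise_1d(signal, clip_size, sigma):
    """Naive non-local means for a 1D signal (quadratic cost, demo only)."""
    signal = np.asarray(signal, dtype=float)
    num_clips = len(signal) - clip_size + 1
    # Row i of idx holds the sample indices of the clip starting at sample i.
    idx = np.arange(num_clips)[:, None] + np.arange(clip_size)[None, :]
    clips = signal[idx]
    denoised_clips = np.empty_like(clips)
    for i in range(num_clips):
        # Squared distance from target clip i to every source clip.
        d2 = np.sum((clips - clips[i]) ** 2, axis=1)
        # Assumed Gaussian weighting: an identical clip gets weight 1.
        w = np.exp(-d2 / (2 * sigma ** 2 * clip_size))
        denoised_clips[i] = w @ clips / w.sum()
    # Each sample belongs to several clips; average their estimates.
    total = np.zeros_like(signal)
    count = np.zeros_like(signal)
    np.add.at(total, idx, denoised_clips)
    np.add.at(count, idx, 1.0)
    return total / count


# Toy data: one waveform repeated 20 times, plus white noise.
rng = np.random.default_rng(0)
t = np.arange(50)
waveform = 3.0 * np.exp(-0.5 * ((t - 25) / 3.0) ** 2)
clean = np.tile(waveform, 20)
noisy = clean + 0.3 * rng.standard_normal(clean.shape)
denoised = nlm_denoise_1d(noisy, clip_size=20, sigma=0.3)
```

Because the waveform repeats, each clip finds many close matches and the average over those matches suppresses the noise while keeping the repeating part of the signal.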
Non-local means is a denoising algorithm that has primarily been used for 2d
images, but it can also be adapted for use with 1d signals. Let
where
is a weighting function, and the summation is over all
The
where
where
is the singular-value decomposition of the
As a practical matter, we need to be able to perform the above denoising procedure within a reasonable timeframe relative to other processing steps. Below we discuss several strategies we use to speed up the computation.
The first simplification is to split the recording into discrete time blocks, typically 30 seconds or one minute long, and denoise each block independently. For firing rates of around 1 Hz or higher this is acceptable, since we only need a few dozen representative events for each neuron. In the future we may provide a way to also probe beyond the boundaries of the discrete blocks, but for now the user must choose a fixed block duration, with a tradeoff between block duration and computational efficiency.
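The block splitting can be pictured with a small helper like the one below. `block_ranges` is a hypothetical function written for this illustration, and how the package treats a final partial block is an assumption here (the sketch simply keeps it as a shorter block).

```python
def block_ranges(num_timepoints, samplerate, block_size_sec):
    """Return (start, end) sample index pairs that tile the recording."""
    block_size = int(round(block_size_sec * samplerate))
    if block_size <= 0:
        raise ValueError('block_size_sec is too small for this sample rate')
    return [
        (start, min(start + block_size, num_timepoints))
        for start in range(0, num_timepoints, block_size)
    ]


# A 75 second recording at 30 kHz, split into 30 second blocks.
ranges = block_ranges(num_timepoints=75 * 30000, samplerate=30000, block_size_sec=30)
```

Each `(start, end)` pair would then be denoised on its own, so memory use and run time per block stay bounded regardless of the total recording length.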
The time-consuming part of the non-local means formula is the summation over all
While algorithms such as clustering or k-nearest neighbors could be used to more intelligently sample, we would like to avoid these types of methods in this preprocessing step. We view clustering and classification as part of the spike sorting step, and not this denoising operation which seeks only to isolate the signal from the noise.
The procedure we use for adaptive subsampling involves computing the summation
in batches and selectively dropping out clips from both the sources (
In addition to dropping out target clips from the computation, it is crucial to also drop out source clips. A large denominator for a target clip means that it must have a relatively large number of nearby neighbors. Therefore, the source clips that overlap (in time) with such a target clip would presumably also have a large number of neighbors, and in fact all of their nearby neighbors would be expected to have a large number of nearby neighbors. Thus it should be safe to drop out source clips as well, based on the denominator criterion for the time-overlapping target clips.
In summary, adaptive subsampling is achieved by accumulating the numerators and denominators of the weighted sum in batches, while dropping out both source and target clips based on the denominator threshold criterion.
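The following NumPy sketch shows one way the batched accumulation with dropout could look. It is a simplified illustration under stated assumptions, not the package's code: `nlm_adaptive_subsample` is a hypothetical name, the weighting is an assumed Gaussian, clips are treated as independent rows, and a clip is dropped as a source as soon as the same clip has been dropped as a target (the time-overlap bookkeeping is not modelled).

```python
import numpy as np


def nlm_adaptive_subsample(clips, sigma, denom_threshold, batch_size=100, seed=0):
    """Batched non-local means over clips with denominator-based dropout.

    Assumes denom_threshold > 0. Returns the denoised clips and the number
    of weights that were actually evaluated.
    """
    clips = np.asarray(clips, dtype=float)
    num_clips, clip_size = clips.shape
    numer = np.zeros_like(clips)   # running weighted sums of source clips
    denom = np.zeros(num_clips)    # running sums of weights
    active = np.ones(num_clips, dtype=bool)
    order = np.random.default_rng(seed).permutation(num_clips)
    num_weights = 0
    for start in range(0, num_clips, batch_size):
        batch = order[start:start + batch_size]
        sources = batch[active[batch]]    # source dropout
        targets = np.flatnonzero(active)  # target dropout
        if len(sources) == 0:
            continue
        diff = clips[targets][:, None, :] - clips[sources][None, :, :]
        d2 = np.sum(diff ** 2, axis=2)
        w = np.exp(-d2 / (2 * sigma ** 2 * clip_size))
        numer[targets] += w @ clips[sources]
        denom[targets] += w.sum(axis=1)
        num_weights += w.size
        # Clips with enough accumulated weight leave the computation.
        active &= denom < denom_threshold
    return numer / denom[:, None], num_weights


# Toy data: 400 noisy copies of the same 30-sample template.
rng = np.random.default_rng(1)
template = 2.0 * np.sin(np.linspace(0, 2 * np.pi, 30))
noisy_clips = template + 0.5 * rng.standard_normal((400, 30))
denoised, num_weights = nlm_adaptive_subsample(
    noisy_clips, sigma=0.5, denom_threshold=10
)
```

In this toy case every clip has many close neighbors, so clips should reach the threshold early and far fewer than the full 400 x 400 weights need to be evaluated. A clip with few neighbors would instead stay active and keep accumulating over later batches.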
Here we need to describe how we keep track of numerators and denominators for each target clip, how exactly we apply the denominator dropout criterion, and how to combine the values for time-overlapping clips.
The method of non-local means works well when the signal repeats many times throughout the time block. But when neural events overlap in time, the resulting waveform is a superposition that will usually not match any other waveform. Even if the same two neurons fire simultaneously in multiple instances, the time offset between the two events is expected to vary, thus producing a spectrum of different superimposed waveforms. Overlapping spike events are thus expected to form isolated clips that have few if any nearby neighbors.
While our method cannot be expected to denoise such events, we can expect the
noisy signal of the superimposed waveforms to survive the denoising process.
This is because only one source term (the clip itself) is expected to contribute
substantially (always with a weight of
Describe the procedure for denoising in neighborhoods.
Describe how we use PyTorch to efficiently compute the matrix-matrix multiplications needed in the above-described algorithm.
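Until this section is written, the sketch below shows a standard way to reduce pairwise clip distances to a single matrix-matrix multiplication in PyTorch. It illustrates the general technique only and is not necessarily how this package does it; `pairwise_sq_dists` is a hypothetical helper and the tensor sizes are arbitrary.

```python
import torch


def pairwise_sq_dists(targets, sources):
    """Squared Euclidean distances between the rows of two 2D tensors.

    Uses ||t - s||^2 = ||t||^2 + ||s||^2 - 2 t.s so that the bulk of the
    work is one matrix-matrix multiplication.
    """
    t_sq = (targets ** 2).sum(dim=1, keepdim=True)      # shape (T, 1)
    s_sq = (sources ** 2).sum(dim=1, keepdim=True).t()  # shape (1, S)
    cross = torch.mm(targets, sources.t())              # shape (T, S)
    # Clamp guards against tiny negative values caused by round-off.
    return torch.clamp(t_sq + s_sq - 2 * cross, min=0)


# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)
targets = torch.randn(200, 30, dtype=torch.float64, device=device)
sources = torch.randn(50, 30, dtype=torch.float64, device=device)
d2 = pairwise_sq_dists(targets, sources)
```

The same code runs on CPU or GPU; only the `device` of the input tensors changes.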
- Python (tested on 3.6 and 3.7)
- PyTorch (tested on v1.0.0)
- CUDA toolkit - if using GPU (recommended)
- MKL - if using CPU instead of CUDA
To test whether PyTorch and CUDA are set up properly, run the following in ipython:
import torch
if torch.cuda.is_available():
    print('CUDA is available for PyTorch!')
else:
    print('CUDA is NOT available for PyTorch.')
Recommended
- SpikeInterface --
pip install spikeinterface
- SpikeForest
pip install --upgrade ephys_nlm
After cloning this repository:
cd ephys_nlm
pip install -e .
# Then in subsequent updates:
git pull
pip install -e .
The following is taken from a notebook in the examples/ directory. It generates a short synthetic recording and denoises it.
from ephys_nlm import ephys_nlm_v1, ephys_nlm_v1_opts
import spikeextractors as se
import spikewidgets as sw
import matplotlib.pyplot as plt
# Create a synthetic recording for purposes of demo
recording, sorting_true = se.example_datasets.toy_example(duration=30, num_channels=4, K=20, seed=4)
# Specify the denoising options
opts = ephys_nlm_v1_opts(
    multi_neighborhood=False, # False means all channels will be denoised in one neighborhood
    block_size_sec=30, # Size of block in seconds -- each block is denoised separately
    clip_size=30, # Size of a clip (aka snippet) in timepoints
    sigma='auto', # Auto compute the noise level
    sigma_scale_factor=1, # Scale factor for auto-computed noise level
    whitening='auto', # Auto compute the whitening matrix
    whitening_pctvar=90, # Percent of variance to retain - controls number of SVD components to keep
    denom_threshold=30 # Higher values lead to a slower but more accurate calculation.
)
# Do the denoising
recording_denoised, runtime_info = ephys_nlm_v1(
    recording=recording,
    opts=opts,
    device='cpu', # cuda is recommended for non-demo situations
    verbose=1
)
Also included in the notebook is SpikeInterface code used to view the original and denoised timeseries:
# View the original and denoised timeseries
plt.figure(figsize=(16,5))
sw.TimeseriesWidget(recording=recording, trange=(0, 0.2), ax=plt.gca()).plot();
plt.figure(figsize=(16,5))
sw.TimeseriesWidget(recording=recording_denoised, trange=(0, 0.2), ax=plt.gca()).plot();
This should produce output similar to the following:
Apache-2.0 -- We request that you acknowledge the original authors in any derivative work.
Jeremy Magland, Center for Computational Mathematics (CCM), Flatiron Institute
Alex Barnett, James Jun, and members of CCM for many useful discussions