Skip to content

Commit

Permalink
Fix ICLabel topographic features on ICA fitted with channel selection (
Browse files Browse the repository at this point in the history
…#68)

* fix channel selection for topoplot

* fix style

* fix tests

* add test

* use testing instead of sample dataset

* fix style

* add entry to changelog

* add gh id
  • Loading branch information
Mathieu Scheltienne authored Jun 16, 2022
1 parent a411e45 commit ab090dd
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 29 deletions.
1 change: 1 addition & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Bug
- Fix shape of ``'y_pred_proba'`` output from :func:`mne_icalabel.label_components` by `Mathieu Scheltienne`_ (:gh:`36`)
- Add a warning if the ICA decomposition provided does not match the expected decomposition by ``ICLabel`` by `Mathieu Scheltienne`_ (:gh:`42`)
- Fix extraction of PSD feature from ``ICLabel`` model on epochs by `Mathieu Scheltienne`_ (:gh:`64`)
- Fix ICLabel topographic features on ICA fitted with a channel selection performed by ``picks`` by `Mathieu Scheltienne`_ (:gh:`68`)

API
~~~
Expand Down
11 changes: 6 additions & 5 deletions mne_icalabel/iclabel/features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import List, Tuple, Union

import numpy as np
from mne import BaseEpochs
Expand Down Expand Up @@ -76,7 +76,7 @@ def get_iclabel_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
icaact = _compute_ica_activations(inst, ica)

# compute topographic feature (float32)
topo = _eeg_topoplot(inst, icawinv)
topo = _eeg_topoplot(inst, icawinv, ica.ch_names)

# compute psd feature (float32)
psd = _eeg_rpsd(inst, ica, icaact)
Expand Down Expand Up @@ -168,12 +168,13 @@ def _compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDAr


# ----------------------------------------------------------------------------
def _eeg_topoplot(inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float]) -> NDArray[float]:
def _eeg_topoplot(
inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float], picks: List[str]
) -> NDArray[float]:
"""Topoplot feature."""
# TODO: Selection of channels is missing.
ncomp = icawinv.shape[-1]
topo = np.zeros((32, 32, 1, ncomp))
rd, th = _mne_to_eeglab_locs(inst)
rd, th = _mne_to_eeglab_locs(inst, picks)
th = np.pi / 180 * th # convert degrees to radians
for it in range(ncomp):
temp_topo = _topoplotFast(icawinv[:, it], rd, th)
Expand Down
4 changes: 2 additions & 2 deletions mne_icalabel/iclabel/tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_topoplotFast(file, eeglab_result_file):
# load ICA
ica = read_ica_eeglab(file)
# convert coordinates
rd, th = _mne_to_eeglab_locs(inst)
rd, th = _mne_to_eeglab_locs(inst, ica.ch_names)
th = np.pi / 180 * th
# get icawinv
icawinv, _ = _retrieve_eeglab_icawinv(ica)
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_eeg_topoplot(file, eeglab_result_file):
# get icawinv
icawinv, _ = _retrieve_eeglab_icawinv(ica)
# compute feature
topo = _eeg_topoplot(inst, icawinv)
topo = _eeg_topoplot(inst, icawinv, ica.ch_names)
# load from eeglab
topo_eeglab = loadmat(eeglab_result_file)["topo"]
# compare
Expand Down
36 changes: 21 additions & 15 deletions mne_icalabel/iclabel/tests/test_label_components.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import pytest
from mne import create_info, make_fixed_length_epochs
from mne.datasets import sample
from mne import create_info, make_fixed_length_epochs, pick_types
from mne.datasets import testing
from mne.io import RawArray, read_raw
from mne.preprocessing import ICA

from mne_icalabel.iclabel import iclabel_label_components

directory = sample.data_path() / "MEG" / "sample"
raw = read_raw(directory / "sample_audvis_raw.fif", preload=False)
raw.crop(0, 100).pick_types(eeg=True, exclude="bads")
directory = testing.data_path() / "MEG" / "sample"
raw = read_raw(directory / "sample_audvis_trunc_raw.fif", preload=False)
raw.pick_types(eeg=True, exclude=[])
raw.load_data()
# preprocess
raw.filter(l_freq=1.0, h_freq=100.0)
Expand All @@ -18,21 +18,27 @@

@pytest.mark.filterwarnings("ignore::RuntimeWarning")
@pytest.mark.parametrize(
"inst",
"inst, exclude",
(
raw,
raw.copy().crop(0, 10),
raw.copy().crop(0, 1),
make_fixed_length_epochs(raw, duration=0.5),
make_fixed_length_epochs(raw, duration=1),
make_fixed_length_epochs(raw, duration=5),
make_fixed_length_epochs(raw, duration=10),
(raw, "bads"),
(raw.copy().crop(0, 8), "bads"),
(raw.copy().crop(0, 1), "bads"),
(make_fixed_length_epochs(raw, duration=0.5, preload=True), "bads"),
(make_fixed_length_epochs(raw, duration=1, preload=True), "bads"),
(make_fixed_length_epochs(raw, duration=5, preload=True), "bads"),
(raw, []),
(raw.copy().crop(0, 8), []),
(raw.copy().crop(0, 1), []),
(make_fixed_length_epochs(raw, duration=0.5, preload=True), []),
(make_fixed_length_epochs(raw, duration=1, preload=True), []),
(make_fixed_length_epochs(raw, duration=5, preload=True), []),
),
)
def test_label_components(inst):
def test_label_components(inst, exclude):
"""Check that label_components does not raise on various data shapes."""
picks = pick_types(raw.info, eeg=True, exclude=exclude)
ica = ICA(n_components=5, method="picard", fit_params=dict(ortho=False, extended=True))
ica.fit(inst)
ica.fit(inst, picks=picks)
labels = iclabel_label_components(inst, ica)
assert labels.shape == (ica.n_components_, 7)

Expand Down
2 changes: 1 addition & 1 deletion mne_icalabel/iclabel/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_loc(file, eeglab_result_file):
when loading the datasets."""
type_ = str(Path(file).stem)[-3:]
inst = reader[type_](file, **kwargs[type_])
rd, th = _mne_to_eeglab_locs(inst)
rd, th = _mne_to_eeglab_locs(inst, picks=inst.ch_names)
eeglab_loc = loadmat(eeglab_result_file)["loc"][0, 0]
eeglab_rd = eeglab_loc["rd"]
eeglab_th = eeglab_loc["th"]
Expand Down
6 changes: 4 additions & 2 deletions mne_icalabel/iclabel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.typing import ArrayLike, NDArray


def _mne_to_eeglab_locs(raw: BaseRaw) -> Tuple[NDArray[float], NDArray[float]]:
def _mne_to_eeglab_locs(raw: BaseRaw, picks: List[str]) -> Tuple[NDArray[float], NDArray[float]]:
"""Obtain EEGLab-like spherical coordinate from EEG channel positions.
TODO: @JACOB:
Expand All @@ -18,6 +18,8 @@ def _mne_to_eeglab_locs(raw: BaseRaw) -> Tuple[NDArray[float], NDArray[float]]:
raw : mne.io.BaseRaw
Instance of raw object with a `mne.montage.DigMontage` set with
``n_channels`` channel positions.
picks : list of str
List of channel names to include.
Returns
-------
Expand All @@ -44,7 +46,7 @@ def _cart2sph(_x, _y, _z):
return azimuth, elevation, r

# get the channel position dictionary
montage = raw.get_montage()
montage = raw.copy().pick_channels(picks, ordered=True).get_montage()
positions = montage.get_positions()
ch_pos = positions["ch_pos"]

Expand Down
8 changes: 4 additions & 4 deletions mne_icalabel/tests/test_label_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
from mne.datasets import sample
from mne.datasets import testing
from mne.io import read_raw
from mne.preprocessing import ICA

from mne_icalabel import label_components

directory = sample.data_path() / "MEG" / "sample"
raw = read_raw(directory / "sample_audvis_raw.fif", preload=False)
raw.crop(0, 10).pick_types(eeg=True, exclude="bads")
directory = testing.data_path() / "MEG" / "sample"
raw = read_raw(directory / "sample_audvis_trunc_raw.fif", preload=False)
raw.pick_types(eeg=True, exclude="bads")
raw.load_data()
# preprocess
raw.filter(l_freq=1.0, h_freq=100.0)
Expand Down

0 comments on commit ab090dd

Please sign in to comment.