Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add copy and channel selection for a Layout object #12338

Merged
merged 15 commits into from
May 30, 2024
Merged
1 change: 1 addition & 0 deletions doc/changes/devel/12338.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adding :meth:`mne.channels.Layout.copy` and :meth:`mne.channels.Layout.pick` to copy and select channels from a :class:`mne.channels.Layout` object. Plotting 2D topographies of evoked responses with :func:`mne.viz.plot_evoked_topo` with both arguments ``layout`` and ``exclude`` now ignores excluded channels from the :class:`mne.channels.Layout`. By `Mathieu Scheltienne`_.
2 changes: 1 addition & 1 deletion doc/development/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ Describe your changes in the changelog

Include in your changeset a brief description of the change in the
:ref:`changelog <whats_new>` using towncrier_ format, which aggregates small,
properly-named ``.rst`` files to create a change log. This can be
properly-named ``.rst`` files to create a changelog. This can be
skipped for very minor changes like correcting typos in the documentation.

There are six separate sections for changes, based on change type.
Expand Down
140 changes: 136 additions & 4 deletions mne/channels/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import logging
from collections import defaultdict
from copy import deepcopy
from itertools import combinations
from pathlib import Path

Expand All @@ -28,8 +29,10 @@
_check_option,
_check_sphere,
_clean_names,
_ensure_int,
fill_doc,
logger,
verbose,
warn,
)
from ..viz.topomap import plot_layout
Expand All @@ -50,9 +53,9 @@ class Layout:
pos : array, shape=(n_channels, 4)
The unit-normalized positions of the channels in 2d
(x, y, width, height).
names : list
names : list of str
The channel names.
ids : list
ids : array-like of int
The channel ids.
kind : str
The type of Layout (e.g. 'Vectorview-all').
Expand All @@ -62,9 +65,25 @@ def __init__(self, box, pos, names, ids, kind):
self.box = box
self.pos = pos
self.names = names
self.ids = ids
self.ids = np.array(ids)
if self.ids.ndim != 1:
raise ValueError("The channel indices should be a 1D array-like.")
self.kind = kind

def copy(self):
"""Return a copy of the layout.

Returns
-------
layout : instance of Layout
A deepcopy of the layout.

Notes
-----
.. versionadded:: 1.7
"""
return deepcopy(self)

def save(self, fname, overwrite=False):
"""Save Layout to disk.

Expand Down Expand Up @@ -135,6 +154,119 @@ def plot(self, picks=None, show_axes=False, show=True):
"""
return plot_layout(self, picks=picks, show_axes=show_axes, show=show)

@verbose
def pick(self, picks=None, exclude=(), *, verbose=None):
"""Pick a subset of channels.

Parameters
----------
%(picks_layout)s
exclude : str | int | array-like of str or int
Set of channels to exclude, only used when ``picks`` is set to ``'all'`` or
``None``. Exclude will not drop channels explicitly provided in ``picks``.
%(verbose)s

Returns
-------
layout : instance of Layout
The modified layout.

Notes
-----
.. versionadded:: 1.7
"""
# TODO: all the picking functions operates on an 'info' object which is missing
# for a layout, thus we have to do the extra work here. The logic below can be
# replaced when https://github.com/mne-tools/mne-python/issues/11913 is solved.
if (isinstance(picks, str) and picks == "all") or (picks is None):
picks = deepcopy(self.names)
apply_exclude = True
elif isinstance(picks, str):
picks = [picks]
apply_exclude = False
elif isinstance(picks, slice):
try:
picks = np.arange(len(self.names))[picks]
except TypeError:
raise TypeError(
"If a slice is provided, it must be a slice of integers."
)
apply_exclude = False
else:
try:
picks = [_ensure_int(picks)]
except TypeError:
picks = (
list(picks) if isinstance(picks, (tuple, set)) else deepcopy(picks)
)
apply_exclude = False
if apply_exclude:
if isinstance(exclude, str):
exclude = [exclude]
else:
try:
exclude = [_ensure_int(exclude)]
except TypeError:
exclude = (
list(exclude)
if isinstance(exclude, (tuple, set))
else deepcopy(exclude)
)
for var, var_name in ((picks, "picks"), (exclude, "exclude")):
if var_name == "exclude" and not apply_exclude:
continue
if not isinstance(var, (list, tuple, set, np.ndarray)):
raise TypeError(
f"'{var_name}' must be a list, tuple, set or ndarray. "
f"Got {type(var)} instead."
)
if isinstance(var, np.ndarray) and var.ndim != 1:
raise ValueError(
f"'{var_name}' must be a 1D array-like. Got {var.ndim}D instead."
)
for k, elt in enumerate(var):
if isinstance(elt, str) and elt in self.names:
var[k] = self.names.index(elt)
continue
elif isinstance(elt, str):
raise ValueError(
f"The channel name {elt} provided in {var_name} does not match "
"any channels from the layout."
)
try:
var[k] = _ensure_int(elt)
except TypeError:
raise TypeError(
f"All elements in '{var_name}' must be integers or strings."
)
if not (0 <= var[k] < len(self.names)):
raise ValueError(
f"The value {elt} provided in {var_name} does not match any "
f"channels from the layout. The layout has {len(self.names)} "
"channels."
)
if len(var) != len(set(var)):
warn(
f"The provided '{var_name}' has duplicates which will be ignored.",
RuntimeWarning,
)
picks = picks.astype(int) if isinstance(picks, np.ndarray) else picks
exclude = exclude.astype(int) if isinstance(exclude, np.ndarray) else exclude
if apply_exclude:
picks = np.array(list(set(picks) - set(exclude)), dtype=int)
if len(picks) == 0:
raise RuntimeError(
"The channel selection yielded no remaining channels. Please edit "
"the arguments 'picks' and 'exclude' to include at least one "
"channel."
)
else:
picks = np.array(list(set(picks)), dtype=int)
self.pos = self.pos[picks]
self.ids = self.ids[picks]
self.names = [self.names[k] for k in picks]
return self
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you check if this logic is not present elsewhere? I would be surprised if it's not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the only oneI found that would work is picks = _picks_to_idx(len(layout.names), picks) which picks on a number of channels. It's more limited and restrictive on the inputs than the logic above.
I do plan to propose a new channel selection API, to try to 1. clean-up all the redundant code and multiple pick functions in mne._fiff.pick and 2. open a public API for channel selection (#11913). Hopefully next week 😉



def _read_lout(fname):
"""Aux function."""
Expand Down Expand Up @@ -533,7 +665,7 @@ def find_layout(info, ch_type=None, exclude="bads"):
idx = [ii for ii, name in enumerate(layout.names) if name not in exclude]
layout.names = [layout.names[ii] for ii in idx]
layout.pos = layout.pos[idx]
layout.ids = [layout.ids[ii] for ii in idx]
layout.ids = layout.ids[idx]

return layout

Expand Down
128 changes: 116 additions & 12 deletions mne/channels/tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from mne._fiff.constants import FIFF
from mne._fiff.meas_info import _empty_info
from mne.channels import (
Layout,
find_layout,
make_eeg_layout,
make_grid_layout,
Expand Down Expand Up @@ -94,6 +95,18 @@ def _get_test_info():
return test_info


@pytest.fixture(scope="module")
def layout():
"""Get a layout."""
return Layout(
(0.1, 0.2, 0.1, 1.2),
pos=np.array([[0, 0, 0.1, 0.1], [0.2, 0.2, 0.1, 0.1], [0.4, 0.4, 0.1, 0.1]]),
names=["0", "1", "2"],
ids=[0, 1, 2],
kind="test",
)


def test_io_layout_lout(tmp_path):
"""Test IO with .lout files."""
layout = read_layout(fname="Vectorview-all", scale=False)
Expand Down Expand Up @@ -224,23 +237,17 @@ def test_make_grid_layout(tmp_path):

def test_find_layout():
"""Test finding layout."""
pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep")
with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"):
find_layout(_get_test_info(), ch_type="meep")

sample_info = read_info(fif_fname)
grads = pick_types(sample_info, meg="grad")
sample_info2 = pick_info(sample_info, grads)

mags = pick_types(sample_info, meg="mag")
sample_info3 = pick_info(sample_info, mags)

# mock new convention
sample_info2 = pick_info(sample_info, pick_types(sample_info, meg="grad"))
sample_info3 = pick_info(sample_info, pick_types(sample_info, meg="mag"))
sample_info4 = copy.deepcopy(sample_info)
for ii, name in enumerate(sample_info4["ch_names"]):
for ii, name in enumerate(sample_info4["ch_names"]): # mock new convention
new = name.replace(" ", "")
sample_info4["chs"][ii]["ch_name"] = new

eegs = pick_types(sample_info, meg=False, eeg=True)
sample_info5 = pick_info(sample_info, eegs)
sample_info5 = pick_info(sample_info, pick_types(sample_info, meg=False, eeg=True))

lout = find_layout(sample_info, ch_type=None)
assert lout.kind == "Vectorview-all"
Expand Down Expand Up @@ -404,3 +411,100 @@ def test_generate_2d_layout():
# Make sure background image normalizing is correct
lt_bg = generate_2d_layout(xy, bg_image=bg_image)
assert_allclose(lt_bg.pos[:, :2].max(), xy.max() / float(sbg))


def test_layout_copy(layout):
"""Test copying a layout."""
layout2 = layout.copy()
assert_allclose(layout.pos, layout2.pos)
assert layout.names == layout2.names
layout2.names[0] = "foo"
layout2.pos[0, 0] = 0.8
assert layout.names != layout2.names
assert layout.pos[0, 0] != layout2.pos[0, 0]


@pytest.mark.parametrize(
"picks, exclude",
[
([0, 1], ()),
(["0", 1], ()),
(None, ["2"]),
(None, "2"),
(None, [2]),
(None, 2),
("all", 2),
("all", "2"),
(slice(0, 2), ()),
(("0", "1"), ("0", "1")),
(("0", 1), ("0", "1")),
(("0", 1), (0, "1")),
(set(["0", 1]), ()),
(set([0, 1]), set()),
(None, set([2])),
(np.array([0, 1]), ()),
(None, np.array([2])),
(np.array(["0", "1"]), ()),
],
)
def test_layout_pick(layout, picks, exclude):
"""Test selection of channels in a layout."""
layout2 = layout.copy()
layout2.pick(picks, exclude)
assert layout2.names == layout.names[:2]
assert_allclose(layout2.pos, layout.pos[:2, :])


def test_layout_pick_more(layout):
"""Test more channel selection in a layout."""
layout2 = layout.copy()
layout2.pick(0)
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[0]
assert_allclose(layout2.pos, layout.pos[:1, :])

layout2 = layout.copy()
layout2.pick("all", exclude=("0", "1"))
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[2]
assert_allclose(layout2.pos, layout.pos[2:, :])

layout2 = layout.copy()
layout2.pick("all", exclude=("0", 1))
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[2]
assert_allclose(layout2.pos, layout.pos[2:, :])


def test_layout_pick_errors(layout):
"""Test validation of layout.pick."""
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
layout.pick(lambda x: x)
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
layout.pick(None, lambda x: x)
with pytest.raises(TypeError, match="must be integers or strings"):
layout.pick([0, lambda x: x])
with pytest.raises(TypeError, match="must be integers or strings"):
layout.pick(None, [0, lambda x: x])
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick("foo")
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(None, "foo")
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(101)
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(None, 101)
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(["0", "0"])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(["0", 0])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(None, ["0", "0"])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(None, ["0", 0])
with pytest.raises(RuntimeError, match="selection yielded no remaining channels"):
layout.copy().pick(None, ["0", "1", "2"])
with pytest.raises(ValueError, match="must be a 1D array-like"):
layout.copy().pick(None, np.array([[0, 1]]))
with pytest.raises(TypeError, match="slice of integers"):
layout.copy().pick(slice("2342342342", 0, 3), ())
Loading
Loading