Skip to content

Commit

Permalink
[MRG] Add copy and channel selection for a Layout object (#12338)
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne authored May 30, 2024
1 parent 9a222ba commit d6a58cb
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 33 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12338.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adding :meth:`mne.channels.Layout.copy` and :meth:`mne.channels.Layout.pick` to copy and select channels from a :class:`mne.channels.Layout` object. Plotting 2D topographies of evoked responses with :func:`mne.viz.plot_evoked_topo` with both arguments ``layout`` and ``exclude`` now ignores excluded channels from the :class:`mne.channels.Layout`. By `Mathieu Scheltienne`_.
2 changes: 1 addition & 1 deletion doc/development/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ Describe your changes in the changelog

Include in your changeset a brief description of the change in the
:ref:`changelog <whats_new>` using towncrier_ format, which aggregates small,
properly-named ``.rst`` files to create a change log. This can be
properly-named ``.rst`` files to create a changelog. This can be
skipped for very minor changes like correcting typos in the documentation.

There are six separate sections for changes, based on change type.
Expand Down
140 changes: 136 additions & 4 deletions mne/channels/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import logging
from collections import defaultdict
from copy import deepcopy
from itertools import combinations
from pathlib import Path

Expand All @@ -28,8 +29,10 @@
_check_option,
_check_sphere,
_clean_names,
_ensure_int,
fill_doc,
logger,
verbose,
warn,
)
from ..viz.topomap import plot_layout
Expand All @@ -50,9 +53,9 @@ class Layout:
pos : array, shape=(n_channels, 4)
The unit-normalized positions of the channels in 2d
(x, y, width, height).
names : list
names : list of str
The channel names.
ids : list
ids : array-like of int
The channel ids.
kind : str
The type of Layout (e.g. 'Vectorview-all').
Expand All @@ -62,9 +65,25 @@ def __init__(self, box, pos, names, ids, kind):
self.box = box
self.pos = pos
self.names = names
self.ids = ids
self.ids = np.array(ids)
if self.ids.ndim != 1:
raise ValueError("The channel indices should be a 1D array-like.")
self.kind = kind

def copy(self):
"""Return a copy of the layout.
Returns
-------
layout : instance of Layout
A deepcopy of the layout.
Notes
-----
.. versionadded:: 1.7
"""
return deepcopy(self)

def save(self, fname, overwrite=False):
"""Save Layout to disk.
Expand Down Expand Up @@ -135,6 +154,119 @@ def plot(self, picks=None, show_axes=False, show=True):
"""
return plot_layout(self, picks=picks, show_axes=show_axes, show=show)

@verbose
def pick(self, picks=None, exclude=(), *, verbose=None):
"""Pick a subset of channels.
Parameters
----------
%(picks_layout)s
exclude : str | int | array-like of str or int
Set of channels to exclude, only used when ``picks`` is set to ``'all'`` or
``None``. Exclude will not drop channels explicitly provided in ``picks``.
%(verbose)s
Returns
-------
layout : instance of Layout
The modified layout.
Notes
-----
.. versionadded:: 1.7
"""
# TODO: all the picking functions operates on an 'info' object which is missing
# for a layout, thus we have to do the extra work here. The logic below can be
# replaced when https://github.com/mne-tools/mne-python/issues/11913 is solved.
if (isinstance(picks, str) and picks == "all") or (picks is None):
picks = deepcopy(self.names)
apply_exclude = True
elif isinstance(picks, str):
picks = [picks]
apply_exclude = False
elif isinstance(picks, slice):
try:
picks = np.arange(len(self.names))[picks]
except TypeError:
raise TypeError(
"If a slice is provided, it must be a slice of integers."
)
apply_exclude = False
else:
try:
picks = [_ensure_int(picks)]
except TypeError:
picks = (
list(picks) if isinstance(picks, (tuple, set)) else deepcopy(picks)
)
apply_exclude = False
if apply_exclude:
if isinstance(exclude, str):
exclude = [exclude]
else:
try:
exclude = [_ensure_int(exclude)]
except TypeError:
exclude = (
list(exclude)
if isinstance(exclude, (tuple, set))
else deepcopy(exclude)
)
for var, var_name in ((picks, "picks"), (exclude, "exclude")):
if var_name == "exclude" and not apply_exclude:
continue
if not isinstance(var, (list, tuple, set, np.ndarray)):
raise TypeError(
f"'{var_name}' must be a list, tuple, set or ndarray. "
f"Got {type(var)} instead."
)
if isinstance(var, np.ndarray) and var.ndim != 1:
raise ValueError(
f"'{var_name}' must be a 1D array-like. Got {var.ndim}D instead."
)
for k, elt in enumerate(var):
if isinstance(elt, str) and elt in self.names:
var[k] = self.names.index(elt)
continue
elif isinstance(elt, str):
raise ValueError(
f"The channel name {elt} provided in {var_name} does not match "
"any channels from the layout."
)
try:
var[k] = _ensure_int(elt)
except TypeError:
raise TypeError(
f"All elements in '{var_name}' must be integers or strings."
)
if not (0 <= var[k] < len(self.names)):
raise ValueError(
f"The value {elt} provided in {var_name} does not match any "
f"channels from the layout. The layout has {len(self.names)} "
"channels."
)
if len(var) != len(set(var)):
warn(
f"The provided '{var_name}' has duplicates which will be ignored.",
RuntimeWarning,
)
picks = picks.astype(int) if isinstance(picks, np.ndarray) else picks
exclude = exclude.astype(int) if isinstance(exclude, np.ndarray) else exclude
if apply_exclude:
picks = np.array(list(set(picks) - set(exclude)), dtype=int)
if len(picks) == 0:
raise RuntimeError(
"The channel selection yielded no remaining channels. Please edit "
"the arguments 'picks' and 'exclude' to include at least one "
"channel."
)
else:
picks = np.array(list(set(picks)), dtype=int)
self.pos = self.pos[picks]
self.ids = self.ids[picks]
self.names = [self.names[k] for k in picks]
return self


def _read_lout(fname):
"""Aux function."""
Expand Down Expand Up @@ -533,7 +665,7 @@ def find_layout(info, ch_type=None, exclude="bads"):
idx = [ii for ii, name in enumerate(layout.names) if name not in exclude]
layout.names = [layout.names[ii] for ii in idx]
layout.pos = layout.pos[idx]
layout.ids = [layout.ids[ii] for ii in idx]
layout.ids = layout.ids[idx]

return layout

Expand Down
128 changes: 116 additions & 12 deletions mne/channels/tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from mne._fiff.constants import FIFF
from mne._fiff.meas_info import _empty_info
from mne.channels import (
Layout,
find_layout,
make_eeg_layout,
make_grid_layout,
Expand Down Expand Up @@ -94,6 +95,18 @@ def _get_test_info():
return test_info


@pytest.fixture(scope="module")
def layout():
"""Get a layout."""
return Layout(
(0.1, 0.2, 0.1, 1.2),
pos=np.array([[0, 0, 0.1, 0.1], [0.2, 0.2, 0.1, 0.1], [0.4, 0.4, 0.1, 0.1]]),
names=["0", "1", "2"],
ids=[0, 1, 2],
kind="test",
)


def test_io_layout_lout(tmp_path):
"""Test IO with .lout files."""
layout = read_layout(fname="Vectorview-all", scale=False)
Expand Down Expand Up @@ -224,23 +237,17 @@ def test_make_grid_layout(tmp_path):

def test_find_layout():
"""Test finding layout."""
pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep")
with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"):
find_layout(_get_test_info(), ch_type="meep")

sample_info = read_info(fif_fname)
grads = pick_types(sample_info, meg="grad")
sample_info2 = pick_info(sample_info, grads)

mags = pick_types(sample_info, meg="mag")
sample_info3 = pick_info(sample_info, mags)

# mock new convention
sample_info2 = pick_info(sample_info, pick_types(sample_info, meg="grad"))
sample_info3 = pick_info(sample_info, pick_types(sample_info, meg="mag"))
sample_info4 = copy.deepcopy(sample_info)
for ii, name in enumerate(sample_info4["ch_names"]):
for ii, name in enumerate(sample_info4["ch_names"]): # mock new convention
new = name.replace(" ", "")
sample_info4["chs"][ii]["ch_name"] = new

eegs = pick_types(sample_info, meg=False, eeg=True)
sample_info5 = pick_info(sample_info, eegs)
sample_info5 = pick_info(sample_info, pick_types(sample_info, meg=False, eeg=True))

lout = find_layout(sample_info, ch_type=None)
assert lout.kind == "Vectorview-all"
Expand Down Expand Up @@ -404,3 +411,100 @@ def test_generate_2d_layout():
# Make sure background image normalizing is correct
lt_bg = generate_2d_layout(xy, bg_image=bg_image)
assert_allclose(lt_bg.pos[:, :2].max(), xy.max() / float(sbg))


def test_layout_copy(layout):
"""Test copying a layout."""
layout2 = layout.copy()
assert_allclose(layout.pos, layout2.pos)
assert layout.names == layout2.names
layout2.names[0] = "foo"
layout2.pos[0, 0] = 0.8
assert layout.names != layout2.names
assert layout.pos[0, 0] != layout2.pos[0, 0]


@pytest.mark.parametrize(
"picks, exclude",
[
([0, 1], ()),
(["0", 1], ()),
(None, ["2"]),
(None, "2"),
(None, [2]),
(None, 2),
("all", 2),
("all", "2"),
(slice(0, 2), ()),
(("0", "1"), ("0", "1")),
(("0", 1), ("0", "1")),
(("0", 1), (0, "1")),
(set(["0", 1]), ()),
(set([0, 1]), set()),
(None, set([2])),
(np.array([0, 1]), ()),
(None, np.array([2])),
(np.array(["0", "1"]), ()),
],
)
def test_layout_pick(layout, picks, exclude):
"""Test selection of channels in a layout."""
layout2 = layout.copy()
layout2.pick(picks, exclude)
assert layout2.names == layout.names[:2]
assert_allclose(layout2.pos, layout.pos[:2, :])


def test_layout_pick_more(layout):
"""Test more channel selection in a layout."""
layout2 = layout.copy()
layout2.pick(0)
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[0]
assert_allclose(layout2.pos, layout.pos[:1, :])

layout2 = layout.copy()
layout2.pick("all", exclude=("0", "1"))
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[2]
assert_allclose(layout2.pos, layout.pos[2:, :])

layout2 = layout.copy()
layout2.pick("all", exclude=("0", 1))
assert len(layout2.names) == 1
assert layout2.names[0] == layout.names[2]
assert_allclose(layout2.pos, layout.pos[2:, :])


def test_layout_pick_errors(layout):
"""Test validation of layout.pick."""
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
layout.pick(lambda x: x)
with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"):
layout.pick(None, lambda x: x)
with pytest.raises(TypeError, match="must be integers or strings"):
layout.pick([0, lambda x: x])
with pytest.raises(TypeError, match="must be integers or strings"):
layout.pick(None, [0, lambda x: x])
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick("foo")
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(None, "foo")
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(101)
with pytest.raises(ValueError, match="does not match any channels"):
layout.pick(None, 101)
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(["0", "0"])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(["0", 0])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(None, ["0", "0"])
with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"):
layout.copy().pick(None, ["0", 0])
with pytest.raises(RuntimeError, match="selection yielded no remaining channels"):
layout.copy().pick(None, ["0", "1", "2"])
with pytest.raises(ValueError, match="must be a 1D array-like"):
layout.copy().pick(None, np.array([[0, 1]]))
with pytest.raises(TypeError, match="slice of integers"):
layout.copy().pick(slice("2342342342", 0, 3), ())
Loading

0 comments on commit d6a58cb

Please sign in to comment.