From d6a58cb6ef9a8d69143c4629c3991ccbf250af42 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 30 May 2024 16:41:11 +0200 Subject: [PATCH] [MRG] Add copy and channel selection for a Layout object (#12338) --- doc/changes/devel/12338.newfeature.rst | 1 + doc/development/contributing.rst | 2 +- mne/channels/layout.py | 140 ++++++++++++++++++++++++- mne/channels/tests/test_layout.py | 128 +++++++++++++++++++--- mne/utils/docs.py | 24 +++-- mne/viz/topo.py | 12 ++- mne/viz/topomap.py | 8 +- 7 files changed, 282 insertions(+), 33 deletions(-) create mode 100644 doc/changes/devel/12338.newfeature.rst diff --git a/doc/changes/devel/12338.newfeature.rst b/doc/changes/devel/12338.newfeature.rst new file mode 100644 index 00000000000..899884d8b61 --- /dev/null +++ b/doc/changes/devel/12338.newfeature.rst @@ -0,0 +1 @@ +Adding :meth:`mne.channels.Layout.copy` and :meth:`mne.channels.Layout.pick` to copy and select channels from a :class:`mne.channels.Layout` object. Plotting 2D topographies of evoked responses with :func:`mne.viz.plot_evoked_topo` with both arguments ``layout`` and ``exclude`` now ignores excluded channels from the :class:`mne.channels.Layout`. By `Mathieu Scheltienne`_. diff --git a/doc/development/contributing.rst b/doc/development/contributing.rst index 04fa49e924b..aec32da00ac 100644 --- a/doc/development/contributing.rst +++ b/doc/development/contributing.rst @@ -592,7 +592,7 @@ Describe your changes in the changelog Include in your changeset a brief description of the change in the :ref:`changelog ` using towncrier_ format, which aggregates small, -properly-named ``.rst`` files to create a change log. This can be +properly-named ``.rst`` files to create a changelog. This can be skipped for very minor changes like correcting typos in the documentation. There are six separate sections for changes, based on change type. diff --git a/mne/channels/layout.py b/mne/channels/layout.py index 83df9f377d7..bcfefe5cf2c 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -12,6 +12,7 @@ import logging from collections import defaultdict +from copy import deepcopy from itertools import combinations from pathlib import Path @@ -28,8 +29,10 @@ _check_option, _check_sphere, _clean_names, + _ensure_int, fill_doc, logger, + verbose, warn, ) from ..viz.topomap import plot_layout @@ -50,9 +53,9 @@ class Layout: pos : array, shape=(n_channels, 4) The unit-normalized positions of the channels in 2d (x, y, width, height). - names : list + names : list of str The channel names. - ids : list + ids : array-like of int The channel ids. kind : str The type of Layout (e.g. 'Vectorview-all'). @@ -62,9 +65,25 @@ def __init__(self, box, pos, names, ids, kind): self.box = box self.pos = pos self.names = names - self.ids = ids + self.ids = np.array(ids) + if self.ids.ndim != 1: + raise ValueError("The channel indices should be a 1D array-like.") self.kind = kind + def copy(self): + """Return a copy of the layout. + + Returns + ------- + layout : instance of Layout + A deepcopy of the layout. + + Notes + ----- + .. versionadded:: 1.7 + """ + return deepcopy(self) + def save(self, fname, overwrite=False): """Save Layout to disk. @@ -135,6 +154,119 @@ def plot(self, picks=None, show_axes=False, show=True): """ return plot_layout(self, picks=picks, show_axes=show_axes, show=show) + @verbose + def pick(self, picks=None, exclude=(), *, verbose=None): + """Pick a subset of channels. + + Parameters + ---------- + %(picks_layout)s + exclude : str | int | array-like of str or int + Set of channels to exclude, only used when ``picks`` is set to ``'all'`` or + ``None``. Exclude will not drop channels explicitly provided in ``picks``. + %(verbose)s + + Returns + ------- + layout : instance of Layout + The modified layout. + + Notes + ----- + .. versionadded:: 1.7 + """ + # TODO: all the picking functions operates on an 'info' object which is missing + # for a layout, thus we have to do the extra work here. The logic below can be + # replaced when https://github.com/mne-tools/mne-python/issues/11913 is solved. + if (isinstance(picks, str) and picks == "all") or (picks is None): + picks = deepcopy(self.names) + apply_exclude = True + elif isinstance(picks, str): + picks = [picks] + apply_exclude = False + elif isinstance(picks, slice): + try: + picks = np.arange(len(self.names))[picks] + except TypeError: + raise TypeError( + "If a slice is provided, it must be a slice of integers." + ) + apply_exclude = False + else: + try: + picks = [_ensure_int(picks)] + except TypeError: + picks = ( + list(picks) if isinstance(picks, (tuple, set)) else deepcopy(picks) + ) + apply_exclude = False + if apply_exclude: + if isinstance(exclude, str): + exclude = [exclude] + else: + try: + exclude = [_ensure_int(exclude)] + except TypeError: + exclude = ( + list(exclude) + if isinstance(exclude, (tuple, set)) + else deepcopy(exclude) + ) + for var, var_name in ((picks, "picks"), (exclude, "exclude")): + if var_name == "exclude" and not apply_exclude: + continue + if not isinstance(var, (list, tuple, set, np.ndarray)): + raise TypeError( + f"'{var_name}' must be a list, tuple, set or ndarray. " + f"Got {type(var)} instead." + ) + if isinstance(var, np.ndarray) and var.ndim != 1: + raise ValueError( + f"'{var_name}' must be a 1D array-like. Got {var.ndim}D instead." + ) + for k, elt in enumerate(var): + if isinstance(elt, str) and elt in self.names: + var[k] = self.names.index(elt) + continue + elif isinstance(elt, str): + raise ValueError( + f"The channel name {elt} provided in {var_name} does not match " + "any channels from the layout." + ) + try: + var[k] = _ensure_int(elt) + except TypeError: + raise TypeError( + f"All elements in '{var_name}' must be integers or strings." + ) + if not (0 <= var[k] < len(self.names)): + raise ValueError( + f"The value {elt} provided in {var_name} does not match any " + f"channels from the layout. The layout has {len(self.names)} " + "channels." + ) + if len(var) != len(set(var)): + warn( + f"The provided '{var_name}' has duplicates which will be ignored.", + RuntimeWarning, + ) + picks = picks.astype(int) if isinstance(picks, np.ndarray) else picks + exclude = exclude.astype(int) if isinstance(exclude, np.ndarray) else exclude + if apply_exclude: + picks = np.array(list(set(picks) - set(exclude)), dtype=int) + if len(picks) == 0: + raise RuntimeError( + "The channel selection yielded no remaining channels. Please edit " + "the arguments 'picks' and 'exclude' to include at least one " + "channel." + ) + else: + picks = np.array(list(set(picks)), dtype=int) + self.pos = self.pos[picks] + self.ids = self.ids[picks] + self.names = [self.names[k] for k in picks] + return self + def _read_lout(fname): """Aux function.""" @@ -533,7 +665,7 @@ def find_layout(info, ch_type=None, exclude="bads"): idx = [ii for ii, name in enumerate(layout.names) if name not in exclude] layout.names = [layout.names[ii] for ii in idx] layout.pos = layout.pos[idx] - layout.ids = [layout.ids[ii] for ii in idx] + layout.ids = layout.ids[idx] return layout diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index 15eb50b7975..6dd6bc630be 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -23,6 +23,7 @@ from mne._fiff.constants import FIFF from mne._fiff.meas_info import _empty_info from mne.channels import ( + Layout, find_layout, make_eeg_layout, make_grid_layout, @@ -94,6 +95,18 @@ def _get_test_info(): return test_info +@pytest.fixture(scope="module") +def layout(): + """Get a layout.""" + return Layout( + (0.1, 0.2, 0.1, 1.2), + pos=np.array([[0, 0, 0.1, 0.1], [0.2, 0.2, 0.1, 0.1], [0.4, 0.4, 0.1, 0.1]]), + names=["0", "1", "2"], + ids=[0, 1, 2], + kind="test", + ) + + def test_io_layout_lout(tmp_path): """Test IO with .lout files.""" layout = read_layout(fname="Vectorview-all", scale=False) @@ -224,23 +237,17 @@ def test_make_grid_layout(tmp_path): def test_find_layout(): """Test finding layout.""" - pytest.raises(ValueError, find_layout, _get_test_info(), ch_type="meep") + with pytest.raises(ValueError, match="Invalid value for the 'ch_type'"): + find_layout(_get_test_info(), ch_type="meep") sample_info = read_info(fif_fname) - grads = pick_types(sample_info, meg="grad") - sample_info2 = pick_info(sample_info, grads) - - mags = pick_types(sample_info, meg="mag") - sample_info3 = pick_info(sample_info, mags) - - # mock new convention + sample_info2 = pick_info(sample_info, pick_types(sample_info, meg="grad")) + sample_info3 = pick_info(sample_info, pick_types(sample_info, meg="mag")) sample_info4 = copy.deepcopy(sample_info) - for ii, name in enumerate(sample_info4["ch_names"]): + for ii, name in enumerate(sample_info4["ch_names"]): # mock new convention new = name.replace(" ", "") sample_info4["chs"][ii]["ch_name"] = new - - eegs = pick_types(sample_info, meg=False, eeg=True) - sample_info5 = pick_info(sample_info, eegs) + sample_info5 = pick_info(sample_info, pick_types(sample_info, meg=False, eeg=True)) lout = find_layout(sample_info, ch_type=None) assert lout.kind == "Vectorview-all" @@ -404,3 +411,100 @@ def test_generate_2d_layout(): # Make sure background image normalizing is correct lt_bg = generate_2d_layout(xy, bg_image=bg_image) assert_allclose(lt_bg.pos[:, :2].max(), xy.max() / float(sbg)) + + +def test_layout_copy(layout): + """Test copying a layout.""" + layout2 = layout.copy() + assert_allclose(layout.pos, layout2.pos) + assert layout.names == layout2.names + layout2.names[0] = "foo" + layout2.pos[0, 0] = 0.8 + assert layout.names != layout2.names + assert layout.pos[0, 0] != layout2.pos[0, 0] + + +@pytest.mark.parametrize( + "picks, exclude", + [ + ([0, 1], ()), + (["0", 1], ()), + (None, ["2"]), + (None, "2"), + (None, [2]), + (None, 2), + ("all", 2), + ("all", "2"), + (slice(0, 2), ()), + (("0", "1"), ("0", "1")), + (("0", 1), ("0", "1")), + (("0", 1), (0, "1")), + (set(["0", 1]), ()), + (set([0, 1]), set()), + (None, set([2])), + (np.array([0, 1]), ()), + (None, np.array([2])), + (np.array(["0", "1"]), ()), + ], +) +def test_layout_pick(layout, picks, exclude): + """Test selection of channels in a layout.""" + layout2 = layout.copy() + layout2.pick(picks, exclude) + assert layout2.names == layout.names[:2] + assert_allclose(layout2.pos, layout.pos[:2, :]) + + +def test_layout_pick_more(layout): + """Test more channel selection in a layout.""" + layout2 = layout.copy() + layout2.pick(0) + assert len(layout2.names) == 1 + assert layout2.names[0] == layout.names[0] + assert_allclose(layout2.pos, layout.pos[:1, :]) + + layout2 = layout.copy() + layout2.pick("all", exclude=("0", "1")) + assert len(layout2.names) == 1 + assert layout2.names[0] == layout.names[2] + assert_allclose(layout2.pos, layout.pos[2:, :]) + + layout2 = layout.copy() + layout2.pick("all", exclude=("0", 1)) + assert len(layout2.names) == 1 + assert layout2.names[0] == layout.names[2] + assert_allclose(layout2.pos, layout.pos[2:, :]) + + +def test_layout_pick_errors(layout): + """Test validation of layout.pick.""" + with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"): + layout.pick(lambda x: x) + with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"): + layout.pick(None, lambda x: x) + with pytest.raises(TypeError, match="must be integers or strings"): + layout.pick([0, lambda x: x]) + with pytest.raises(TypeError, match="must be integers or strings"): + layout.pick(None, [0, lambda x: x]) + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick("foo") + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(None, "foo") + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(101) + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(None, 101) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(["0", "0"]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(["0", 0]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(None, ["0", "0"]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(None, ["0", 0]) + with pytest.raises(RuntimeError, match="selection yielded no remaining channels"): + layout.copy().pick(None, ["0", "1", "2"]) + with pytest.raises(ValueError, match="must be a 1D array-like"): + layout.copy().pick(None, np.array([[0, 1]])) + with pytest.raises(TypeError, match="slice of integers"): + layout.copy().pick(slice("2342342342", 0, 3), ()) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index f1747a7f626..9aabafb90b6 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3358,8 +3358,8 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): pick channels of those types,""" _picks_str_names = """channel *name* strings (e.g., ``['MEG0111', 'MEG2623']`` will pick the given channels.""" -_picks_str_values = """Can also be the string values "all" to pick - all channels, or "data" to pick :term:`data channels`.""" +_picks_str_values = """Can also be the string values ``'all'`` to pick + all channels, or ``'data'`` to pick :term:`data channels`.""" _picks_str = f"""In lists, {_picks_str_types} {_picks_str_names} {_picks_str_values} None (default) will pick""" @@ -3400,8 +3400,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): If an integer, represents the index of the IC to pick. Multiple ICs can be selected using a list of int or a slice. The indices are 0-indexed, so ``picks=1`` will pick the second - IC: ``ICA001``. ``None`` will pick all independent components in the order - fitted. + IC: ``ICA001``. ``None`` will pick all independent components in the order fitted. +""" +docdict["picks_layout"] = """ +picks : array-like of str or int | slice | ``'all'`` | None + Channels to include in the layout. Slices and lists of integers will be interpreted + as channel indices. Can also be the string value ``'all'`` to pick all channels. + None (default) will pick all channels. """ docdict["picks_nostr"] = f"""picks : list | slice | None {_picks_desc} {_picks_int} @@ -3490,7 +3495,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.24 .. versionchanged:: 1.0 - Support for the MNE_BROWSER_PRECOMPUTE config variable. + Support for the ``MNE_BROWSER_PRECOMPUTE`` config variable. """ docdict["preload"] = """ @@ -3528,9 +3533,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["proj_plot"] = """ proj : bool | 'interactive' | 'reconstruct' - If true SSP projections are applied before display. If 'interactive', + If true SSP projections are applied before display. If ``'interactive'``, a check box for reversible selection of SSP projection vectors will - be shown. If 'reconstruct', projection vectors will be applied and then + be shown. If ``'reconstruct'``, projection vectors will be applied and then M/EEG data will be reconstructed via field mapping to reduce the signal bias caused by projection. @@ -3823,9 +3828,8 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["return_pca_vars_pctf"] = """ return_pca_vars : bool Whether or not to return the explained variances across the specified - vertices for individual SVD components. This is only valid if - mode='svd'. - Default return_pca_vars=False. + vertices for individual SVD components. This is only valid if ``mode='svd'``. + Default to False. """ docdict["roll"] = """ diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 13319cc586c..d43c7610ca3 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -14,7 +14,7 @@ import numpy as np from scipy import ndimage -from .._fiff.pick import channel_type, pick_types +from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc from .utils import ( @@ -974,6 +974,16 @@ def _plot_evoked_topo( if layout is None: layout = find_layout(info, exclude=exclude) + else: + layout = layout.pick( + "all", + exclude=_picks_to_idx( + info, + exclude if exclude != "bads" else info["bads"], + exclude=(), + allow_empty=True, + ), + ) if not merge_channels: # XXX. at the moment we are committed to 1- / 2-sensor-types layouts diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index 45bb167c997..99ead332308 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -2883,7 +2883,7 @@ def plot_layout(layout, picks=None, show_axes=False, show=True): ---------- layout : None | Layout Layout instance specifying sensor positions. - %(picks_nostr)s + %(picks_layout)s show_axes : bool Show layout axes if True. Defaults to False. show : bool @@ -2907,10 +2907,8 @@ def plot_layout(layout, picks=None, show_axes=False, show=True): ax.set(xticks=[], yticks=[], aspect="equal") outlines = dict(border=([0, 1, 1, 0, 0], [0, 0, 1, 1, 0])) _draw_outlines(ax, outlines) - picks = _picks_to_idx(len(layout.names), picks) - pos = layout.pos[picks] - names = np.array(layout.names)[picks] - for ii, (p, ch_id) in enumerate(zip(pos, names)): + layout = layout.copy().pick(picks) + for ii, (p, ch_id) in enumerate(zip(layout.pos, layout.names)): center_pos = np.array((p[0] + p[2] / 2.0, p[1] + p[3] / 2.0)) ax.annotate( ch_id,