From d4b3da5ac64562812392e8807fee6f9d550f2e77 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Sat, 6 Jan 2024 16:57:45 +0100 Subject: [PATCH] add tests and fix handling of immutable case --- mne/channels/layout.py | 14 ++++++++++++-- mne/channels/tests/test_layout.py | 27 ++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/mne/channels/layout.py b/mne/channels/layout.py index 06b3386cf5a..d5ba2b9567f 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -184,7 +184,7 @@ def pick(self, picks=None, exclude=(), *, verbose=None): try: picks = [_ensure_int(picks)] except TypeError: - picks = list(deepcopy(picks)) + picks = list(picks) if isinstance(picks, tuple) else deepcopy(picks) apply_exclude = False if apply_exclude: if isinstance(exclude, str): @@ -193,7 +193,11 @@ def pick(self, picks=None, exclude=(), *, verbose=None): try: exclude = [_ensure_int(exclude)] except TypeError: - exclude = list(deepcopy(exclude)) + exclude = ( + list(exclude) + if isinstance(exclude, tuple) + else deepcopy(exclude) + ) for var, var_name in ((picks, "picks"), (exclude, "exclude")): if var_name == "exclude" and not apply_exclude: continue @@ -230,6 +234,12 @@ def pick(self, picks=None, exclude=(), *, verbose=None): ) if apply_exclude: picks = np.array(list(set(picks) - set(exclude))) + if len(picks) == 0: + raise RuntimeError( + "The channel selection yielded no remaining channels. Please edit " + "the arguments 'picks' and 'exclude' to include at least one " + "channel." + ) else: picks = np.array(list(set(picks))) self.pos = self.pos[picks] diff --git a/mne/channels/tests/test_layout.py b/mne/channels/tests/test_layout.py index e065b687e2d..57eb98f725b 100644 --- a/mne/channels/tests/test_layout.py +++ b/mne/channels/tests/test_layout.py @@ -472,4 +472,29 @@ def test_layout_pick_more(layout): def test_layout_pick_errors(layout): """Test validation of layout.pick.""" - pass + with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"): + layout.pick(lambda x: x) + with pytest.raises(TypeError, match="must be a list, tuple, set or ndarray"): + layout.pick(None, lambda x: x) + with pytest.raises(TypeError, match="must be integers or strings"): + layout.pick([0, lambda x: x]) + with pytest.raises(TypeError, match="must be integers or strings"): + layout.pick(None, [0, lambda x: x]) + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick("foo") + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(None, "foo") + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(101) + with pytest.raises(ValueError, match="does not match any channels"): + layout.pick(None, 101) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(["0", "0"]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(["0", 0]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(None, ["0", "0"]) + with pytest.warns(RuntimeWarning, match="has duplicates which will be ignored"): + layout.copy().pick(None, ["0", 0]) + with pytest.raises(RuntimeError, match="selection yielded no remaining channels"): + layout.copy().pick(None, ["0", "1", "2"])