diff --git a/mne/epochs.py b/mne/epochs.py index bb61e0cc62c..fdba9d23cc6 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -700,8 +700,9 @@ def _check_consistency(self): assert hasattr(self, "_times_readonly") assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) + print("self.drop_log", self.drop_log) assert all(isinstance(log, tuple) for log in self.drop_log) - assert all(isinstance(s, str) for log in self.drop_log for s in log) + assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log) def reset_drop_log_selection(self): """Reset the drop_log and selection entries. @@ -793,9 +794,14 @@ def _reject_setup(self, reject, flat): reject = deepcopy(reject) if reject is not None else dict() flat = deepcopy(flat) if flat is not None else dict() for rej, kind in zip((reject, flat), ("reject", "flat")): + if not isinstance(rej, dict): + raise TypeError( + "reject and flat must be dict or None, not %s" % type(rej) + ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) + raise KeyError( + "Unknown channel types found in %s: %s" % (kind, bads)) for key in idx.keys(): # don't throw an error if rejection/flat would do nothing @@ -811,40 +817,30 @@ def _reject_setup(self, reject, flat): ) # check for invalid values - for rej, kind in zip((reject, flat), ("Rejection", "Flat")): - if not isinstance(rej, dict): - raise TypeError( - "reject and flat must be dict or None, not %s" % type(rej) - ) - - # Check if each reject/flat dict is a tuple that contains a - # callable function and a collection or string - for key, val in rej.items(): - if isinstance(val, (list, tuple)): - if callable(val[0]): - continue - elif val[0] is not None and val[0] >= 0: - continue - else: - raise ValueError( - "%s criteria must be a number >= 0 or a valid" - ' callable, not "%s"' % (kind, val) + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): + for key, val in rej.items(): + name = f"{kind} dict value for {key}" + if isinstance(val, (list, tuple)): + _validate_type( + val[0], ("numeric", "callable"), + val[0], "float, int, or callable" ) - if isinstance(val[1], (list, tuple, str)): + if ( + isinstance(val[0], (int, float)) and + (val[0] is None or val[0] < 0) + ): + raise ValueError( + """If using numerical %s criteria, the value + must be >= 0 Not '%s'.""" % (kind, val[0]) + ) + _validate_type(val[1], ("str", "array-like"), val[1]) continue - else: + _validate_type(val, "numeric", name, extra="or callable") + if val is None or val < 0: raise ValueError( - "%s reason must be a collection or string, " - "not %s" % (kind, type(val[1])) + """If using numerical %s criteria, the value + must be >= 0 Not '%s'.""" % (kind, val) ) - else: - raise ValueError( - """The dictionary elements in %s must be in the - form of a collection that contains a callable or value - in the first element and a collection or string - in the second element""" - % rej - ) # now check to see if our rejection and flat are getting more # restrictive @@ -1565,6 +1561,16 @@ def drop(self, indices, reason="USER", verbose=None): if indices.ndim > 1: raise ValueError("indices must be a scalar or a 1-d array") + # Check if indices and reasons are of the same length + # if using collection to drop epochs + if (isinstance(reason, (list, tuple))): + if len(indices) != len(reason): + raise ValueError( + "If using a list or tuple as the reason, " + "indices and reasons must be of the same length, got " + f"{len(indices)} and {len(reason)}" + ) + if indices.dtype == bool: indices = np.where(indices)[0] @@ -1767,7 +1773,7 @@ def _get_data( is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) - assert all(isinstance(x, str) for x in bad_tuple) + assert all(isinstance(x, (str, tuple)) for x in bad_tuple) drop_log[sel] = drop_log[sel] + bad_tuple continue good_idx.append(idx) @@ -3715,7 +3721,10 @@ def _is_good( for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, refl in refl.items(): - criterion = refl[0] + if isinstance(refl, (tuple, list)): + criterion = refl[0] + else: + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: @@ -3734,17 +3743,26 @@ def _is_good( )[0] if len(idx_deltas) > 0: - bad_names = [ch_names[idx[i]] for i in idx_deltas] - if not has_printed: - logger.info( - " Rejecting %s epoch based on %s : " - "%s" % (t, name, bad_names) - ) - has_printed = True - if not full_report: - return False + if isinstance(refl, (tuple, list)): + reasons = list(refl[1]) + for idx, reason in enumerate(reasons): + if isinstance(reason, str): + reasons[idx] = (reason,) + if isinstance(reason, list): + reasons[idx] = tuple(reason) + bad_tuple += tuple(reasons) else: - bad_tuple += tuple(bad_names) + bad_names = [ch_names[idx[i]] for i in idx_deltas] + if not has_printed: + logger.info( + " Rejecting %s epoch based on %s : " + "%s" % (t, name, bad_names) + ) + has_printed = True + if not full_report: + return False + else: + bad_tuple += tuple(bad_names) if not full_report: return True diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 9f9d24d31a7..8d9c13afdd6 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -488,10 +488,11 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True - assert isinstance(drop_log, tuple), "drop_log should be tuple" + assert isinstance(drop_log, (tuple, list)), """drop_log should be tuple + or list""" assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" + isinstance(log, (tuple, list)) for log in drop_log + ), "drop_log[ii] should be tuple or list" assert all( isinstance(s, str) for log in drop_log for s in log ), "drop_log[ii][jj] should be str" @@ -549,7 +550,7 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (None, -1): # protect against older MNE-C types + for val in (-1, (-1, 'Hi')): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -563,6 +564,21 @@ def test_reject(): preload=False, **{kwarg: dict(grad=val)}, ) + bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)] + for val in bad_types: # protect against bad types + for kwarg in ("reject", "flat"): + pytest.raises( + TypeError, + Epochs, + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) pytest.raises( KeyError, Epochs, @@ -2175,7 +2191,7 @@ def test_callable_reject(): tmax=1, baseline=None, reject=dict( - eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False + eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median") ), preload=True, ) @@ -2187,7 +2203,7 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), "eeg max")), preload=True, ) assert epochs.drop_log[0] != () @@ -2195,7 +2211,7 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return True if max_condition.any() or median_condition.any() else False + (max_condition.any() or median_condition.any(), "eeg max or median") epochs = mne.Epochs( edit_raw, @@ -2206,6 +2222,7 @@ def reject_criteria(x): reject=dict(eeg=reject_criteria), preload=True, ) + print(epochs.drop_log) assert epochs.drop_log[0] != () and epochs.drop_log[2] != () @@ -3262,6 +3279,22 @@ def test_drop_epochs(): assert_array_equal(events[epochs[3:].selection], events1[[5, 6]]) assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]]) + # Test using tuple to drop epochs + raw, events, picks = _get_data() + epochs_tuple = Epochs( + raw, events, event_id, + tmin, tmax, picks=picks, preload=True + ) + selection_tuple = epochs_tuple.selection.copy() + epochs_tuple.drop((2, 3, 4), reason=([['list'], 'string', ('tuple',)])) + n_events = len(epochs.events) + assert_equal( + [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], [ + ['list'], ["string"], ['tuple']] + ) + + + @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 8d7320fbb8d..c7cfa06297e 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -210,14 +210,22 @@ def _getitem( drop_log = list(inst.drop_log) if reason is not None: # Used for multiple reasons - if isinstance(reason, (list, tuple)): + if isinstance(reason, tuple): + reason = list(reason) + + if isinstance(reason, list): for i, idx in enumerate( np.setdiff1d(inst.selection, key_selection) ): - drop_log[idx] = reason[i] + r = reason[i] + if isinstance(r, str): + r = (r,) + if isinstance(r, list): + r = tuple(r) + drop_log[idx] = r else: for idx in np.setdiff1d(inst.selection, key_selection): - drop_log[idx] = reason + drop_log[idx] = (reason,) inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log