diff --git a/mne/_version.py b/mne/_version.py new file mode 100644 index 00000000000..c741fa16728 --- /dev/null +++ b/mne/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '1.6.0.dev139+gfdaeb8620' +__version_tuple__ = version_tuple = (1, 6, 0, 'dev139', 'gfdaeb8620') diff --git a/mne/epochs.py b/mne/epochs.py index fdba9d23cc6..e60e235f5cd 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -700,9 +700,8 @@ def _check_consistency(self): assert hasattr(self, "_times_readonly") assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) - print("self.drop_log", self.drop_log) assert all(isinstance(log, tuple) for log in self.drop_log) - assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log) + assert all(isinstance(s, str) for log in self.drop_log for s in log) def reset_drop_log_selection(self): """Reset the drop_log and selection entries. @@ -820,20 +819,7 @@ def _reject_setup(self, reject, flat): for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): name = f"{kind} dict value for {key}" - if isinstance(val, (list, tuple)): - _validate_type( - val[0], ("numeric", "callable"), - val[0], "float, int, or callable" - ) - if ( - isinstance(val[0], (int, float)) and - (val[0] is None or val[0] < 0) - ): - raise ValueError( - """If using numerical %s criteria, the value - must be >= 0 Not '%s'.""" % (kind, val[0]) - ) - _validate_type(val[1], ("str", "array-like"), val[1]) + if callable(val): continue _validate_type(val, "numeric", name, extra="or callable") if val is None or val < 0: @@ -3721,20 +3707,30 @@ def _is_good( for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, refl in refl.items(): - if isinstance(refl, (tuple, list)): - criterion = refl[0] - else: - criterion = refl + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: e_idx = e[idx] checkable_idx = checkable[idx] - # Check if criterion is a function and apply it if callable(criterion): + result = criterion(e_idx) + _validate_type(result, tuple, result, "tuple") + if len(result) != 2: + raise TypeError( + "Function criterion must return a " + "tuple of length 2" + ) + cri_truth, reasons = result + _validate_type(cri_truth, (bool, np.bool_), + cri_truth, "bool") + _validate_type( + reasons, (str, list, tuple), + reasons, "str, list, or tuple" + ) idx_deltas = np.where( - np.logical_and(criterion(e_idx), checkable_idx) + np.logical_and(cri_truth, checkable_idx) )[0] else: deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) @@ -3743,14 +3739,16 @@ def _is_good( )[0] if len(idx_deltas) > 0: - if isinstance(refl, (tuple, list)): - reasons = list(refl[1]) - for idx, reason in enumerate(reasons): - if isinstance(reason, str): - reasons[idx] = (reason,) - if isinstance(reason, list): - reasons[idx] = tuple(reason) - bad_tuple += tuple(reasons) + # Check to verify that refl is a callable that returns + # (bool, reason). Reason must be a str/list/tuple. + # If using tuple + if callable(refl): + if isinstance(reasons, (tuple, list)): + for idx, reason in enumerate(reasons): + _validate_type(reason, str, reason, "str") + bad_tuple += tuple(reasons) + if isinstance(reasons, str): + bad_tuple += (reasons,) else: bad_names = [ch_names[idx[i]] for i in idx_deltas] if not has_printed: diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 8d9c13afdd6..f71eabda717 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -550,7 +550,7 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (-1, (-1, 'Hi')): # protect against older MNE-C types + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -564,7 +564,17 @@ def test_reject(): preload=False, **{kwarg: dict(grad=val)}, ) - bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)] + + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) + return len(bad_idxs) > 0 + + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) + reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs) + return len(bad_idxs), reasons + + bad_types = [my_reject_1, my_reject_2, ('Hi' 'Hi'), (1, 1)] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): pytest.raises( @@ -576,7 +586,7 @@ def test_reject(): tmin, tmax, picks=picks_meg, - preload=False, + preload=True, **{kwarg: dict(grad=val)}, ) pytest.raises( @@ -2211,7 +2221,10 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - (max_condition.any() or median_condition.any(), "eeg max or median") + return ( + (max_condition.any() or median_condition.any()), + "eeg max or median" + ) epochs = mne.Epochs( edit_raw, diff --git a/mne/utils/docs.py b/mne/utils/docs.py index f64dbf4e66e..79e6bb945e2 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3291,9 +3291,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Custom rejection criteria can be also be used by passing a callable, e.g., to check for 99th percentile of absolute values of any channel - across time being bigger than 1mV:: + across time being bigger than 1mV. The callable must return a good, reason tuple. + Where good must be bool and reason must be str, list, or tuple where each entry is a str.:: - reject = dict(eeg=lambda x: (np.percentile(np.abs(x), 99, axis=1) > 1e-3).any()) + reject = dict(eeg=lambda x: ((np.percentile(np.abs(x), 99, axis=1) > 1e-3).any(), "> 1mV somewhere")) .. note:: If rejection is based on a signal **difference** calculated for each channel separately, applying baseline diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 1690371b930..27f1085f713 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -366,7 +366,9 @@ # spike in amplitude and one large increase in amplitude. # Let's try to reject the epoch containing the spike in amplitude based on the -# maximum amplitude of the first channel. +# maximum amplitude of the first channel. Please note that the callable in +# ``reject`` must return a (good, reason) tuple. Where the good must be bool +# and reason must be a str, list, or tuple where each entry is a str. epochs = mne.Epochs( edit_raw, @@ -374,7 +376,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: (np.max(x, axis=1) > 1e-2).any()), + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -395,7 +397,9 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: (np.median(x, axis=1) > 1e-4).any()), + reject=dict( + eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp") + ), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -409,7 +413,10 @@ def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return True if max_condition.any() or median_condition.any() else False + return ( + (max_condition.any() or median_condition.any()), + ["max amp", "median amp"] + ) epochs = mne.Epochs(