Skip to content

Commit

Permalink
allow callables
Browse files Browse the repository at this point in the history
  • Loading branch information
withmywoessner committed Jan 9, 2024
1 parent 8401c92 commit cf1facf
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 40 deletions.
16 changes: 16 additions & 0 deletions mne/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '1.6.0.dev139+gfdaeb8620'
__version_tuple__ = version_tuple = (1, 6, 0, 'dev139', 'gfdaeb8620')
58 changes: 28 additions & 30 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,8 @@ def _check_consistency(self):
assert hasattr(self, "_times_readonly")
assert not self.times.flags["WRITEABLE"]
assert isinstance(self.drop_log, tuple)
print("self.drop_log", self.drop_log)
assert all(isinstance(log, tuple) for log in self.drop_log)
assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log)
assert all(isinstance(s, str) for log in self.drop_log for s in log)

def reset_drop_log_selection(self):
"""Reset the drop_log and selection entries.
Expand Down Expand Up @@ -820,20 +819,7 @@ def _reject_setup(self, reject, flat):
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
if isinstance(val, (list, tuple)):
_validate_type(
val[0], ("numeric", "callable"),
val[0], "float, int, or callable"
)
if (
isinstance(val[0], (int, float)) and
(val[0] is None or val[0] < 0)
):
raise ValueError(
"""If using numerical %s criteria, the value
must be >= 0 Not '%s'.""" % (kind, val[0])
)
_validate_type(val[1], ("str", "array-like"), val[1])
if callable(val):
continue
_validate_type(val, "numeric", name, extra="or callable")
if val is None or val < 0:
Expand Down Expand Up @@ -3721,20 +3707,30 @@ def _is_good(
for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]):
if refl is not None:
for key, refl in refl.items():
if isinstance(refl, (tuple, list)):
criterion = refl[0]
else:
criterion = refl
criterion = refl
idx = channel_type_idx[key]
name = key.upper()
if len(idx) > 0:
e_idx = e[idx]
checkable_idx = checkable[idx]

# Check if criterion is a function and apply it
if callable(criterion):
result = criterion(e_idx)
_validate_type(result, tuple, result, "tuple")
if len(result) != 2:
raise TypeError(
"Function criterion must return a "
"tuple of length 2"
)
cri_truth, reasons = result
_validate_type(cri_truth, (bool, np.bool_),
cri_truth, "bool")
_validate_type(
reasons, (str, list, tuple),
reasons, "str, list, or tuple"
)
idx_deltas = np.where(
np.logical_and(criterion(e_idx), checkable_idx)
np.logical_and(cri_truth, checkable_idx)
)[0]
else:
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
Expand All @@ -3743,14 +3739,16 @@ def _is_good(
)[0]

if len(idx_deltas) > 0:
if isinstance(refl, (tuple, list)):
reasons = list(refl[1])
for idx, reason in enumerate(reasons):
if isinstance(reason, str):
reasons[idx] = (reason,)
if isinstance(reason, list):
reasons[idx] = tuple(reason)
bad_tuple += tuple(reasons)
# Check to verify that refl is a callable that returns
# (bool, reason). Reason must be a str/list/tuple.
# If using tuple
if callable(refl):
if isinstance(reasons, (tuple, list)):
for idx, reason in enumerate(reasons):
_validate_type(reason, str, reason, "str")
bad_tuple += tuple(reasons)
if isinstance(reasons, str):
bad_tuple += (reasons,)
else:
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
Expand Down
21 changes: 17 additions & 4 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def test_reject():
preload=False,
reject=dict(eeg=np.inf),
)
for val in (-1, (-1, 'Hi')): # protect against older MNE-C types
for val in (-1, -2): # protect against older MNE-C types
for kwarg in ("reject", "flat"):
pytest.raises(
ValueError,
Expand All @@ -564,7 +564,17 @@ def test_reject():
preload=False,
**{kwarg: dict(grad=val)},
)
bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)]

def my_reject_1(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35)
return len(bad_idxs) > 0

def my_reject_2(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35)
reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs)
return len(bad_idxs), reasons

bad_types = [my_reject_1, my_reject_2, ('Hi' 'Hi'), (1, 1)]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
pytest.raises(
Expand All @@ -576,7 +586,7 @@ def test_reject():
tmin,
tmax,
picks=picks_meg,
preload=False,
preload=True,
**{kwarg: dict(grad=val)},
)
pytest.raises(
Expand Down Expand Up @@ -2211,7 +2221,10 @@ def test_callable_reject():
def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
(max_condition.any() or median_condition.any(), "eeg max or median")
return (
(max_condition.any() or median_condition.any()),
"eeg max or median"
)

epochs = mne.Epochs(
edit_raw,
Expand Down
5 changes: 3 additions & 2 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3291,9 +3291,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
Custom rejection criteria can be also be used by passing a callable,
e.g., to check for 99th percentile of absolute values of any channel
across time being bigger than 1mV::
across time being bigger than 1mV. The callable must return a good, reason tuple.
Where good must be bool and reason must be str, list, or tuple where each entry is a str.::
reject = dict(eeg=lambda x: (np.percentile(np.abs(x), 99, axis=1) > 1e-3).any())
reject = dict(eeg=lambda x: ((np.percentile(np.abs(x), 99, axis=1) > 1e-3).any(), "> 1mV somewhere"))
.. note:: If rejection is based on a signal **difference**
calculated for each channel separately, applying baseline
Expand Down
15 changes: 11 additions & 4 deletions tutorials/preprocessing/20_rejecting_bad_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,17 @@
# spike in amplitude and one large increase in amplitude.

# Let's try to reject the epoch containing the spike in amplitude based on the
# maximum amplitude of the first channel.
# maximum amplitude of the first channel. Please note that the callable in
# ``reject`` must return a (good, reason) tuple. Where the good must be bool
# and reason must be a str, list, or tuple where each entry is a str.

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: (np.max(x, axis=1) > 1e-2).any()),
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")),
preload=True,
)
epochs.plot(scalings=dict(eeg=50e-5))
Expand All @@ -395,7 +397,9 @@
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: (np.median(x, axis=1) > 1e-4).any()),
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")
),
preload=True,
)
epochs.plot(scalings=dict(eeg=50e-5))
Expand All @@ -409,7 +413,10 @@
def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return True if max_condition.any() or median_condition.any() else False
return (
(max_condition.any() or median_condition.any()),
["max amp", "median amp"]
)


epochs = mne.Epochs(
Expand Down

0 comments on commit cf1facf

Please sign in to comment.