Skip to content

Commit

Permalink
return callable/reasons
Browse files Browse the repository at this point in the history
  • Loading branch information
withmywoessner committed Jan 9, 2024
1 parent bce6486 commit 8401c92
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 54 deletions.
106 changes: 62 additions & 44 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,9 @@ def _check_consistency(self):
assert hasattr(self, "_times_readonly")
assert not self.times.flags["WRITEABLE"]
assert isinstance(self.drop_log, tuple)
print("self.drop_log", self.drop_log)
assert all(isinstance(log, tuple) for log in self.drop_log)
assert all(isinstance(s, str) for log in self.drop_log for s in log)
assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log)

def reset_drop_log_selection(self):
"""Reset the drop_log and selection entries.
Expand Down Expand Up @@ -793,9 +794,14 @@ def _reject_setup(self, reject, flat):
reject = deepcopy(reject) if reject is not None else dict()
flat = deepcopy(flat) if flat is not None else dict()
for rej, kind in zip((reject, flat), ("reject", "flat")):
if not isinstance(rej, dict):
raise TypeError(
"reject and flat must be dict or None, not %s" % type(rej)
)
bads = set(rej.keys()) - set(idx.keys())
if len(bads) > 0:
raise KeyError("Unknown channel types found in %s: %s" % (kind, bads))
raise KeyError(
"Unknown channel types found in %s: %s" % (kind, bads))

for key in idx.keys():
# don't throw an error if rejection/flat would do nothing
Expand All @@ -811,40 +817,30 @@ def _reject_setup(self, reject, flat):
)

# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
if not isinstance(rej, dict):
raise TypeError(
"reject and flat must be dict or None, not %s" % type(rej)
)

# Check if each reject/flat dict is a tuple that contains a
# callable function and a collection or string
for key, val in rej.items():
if isinstance(val, (list, tuple)):
if callable(val[0]):
continue
elif val[0] is not None and val[0] >= 0:
continue
else:
raise ValueError(
"%s criteria must be a number >= 0 or a valid"
' callable, not "%s"' % (kind, val)
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
if isinstance(val, (list, tuple)):
_validate_type(
val[0], ("numeric", "callable"),
val[0], "float, int, or callable"
)
if isinstance(val[1], (list, tuple, str)):
if (
isinstance(val[0], (int, float)) and
(val[0] is None or val[0] < 0)
):
raise ValueError(
"""If using numerical %s criteria, the value
must be >= 0 Not '%s'.""" % (kind, val[0])
)
_validate_type(val[1], ("str", "array-like"), val[1])
continue
else:
_validate_type(val, "numeric", name, extra="or callable")
if val is None or val < 0:
raise ValueError(
"%s reason must be a collection or string, "
"not %s" % (kind, type(val[1]))
"""If using numerical %s criteria, the value
must be >= 0 Not '%s'.""" % (kind, val)
)
else:
raise ValueError(
"""The dictionary elements in %s must be in the
form of a collection that contains a callable or value
in the first element and a collection or string
in the second element"""
% rej
)

# now check to see if our rejection and flat are getting more
# restrictive
Expand Down Expand Up @@ -1565,6 +1561,16 @@ def drop(self, indices, reason="USER", verbose=None):

if indices.ndim > 1:
raise ValueError("indices must be a scalar or a 1-d array")
# Check if indices and reasons are of the same length
# if using collection to drop epochs
if (isinstance(reason, (list, tuple))):
if len(indices) != len(reason):
raise ValueError(
"If using a list or tuple as the reason, "
"indices and reasons must be of the same length, got "
f"{len(indices)} and {len(reason)}"
)


if indices.dtype == bool:
indices = np.where(indices)[0]
Expand Down Expand Up @@ -1767,7 +1773,7 @@ def _get_data(
is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose)
if not is_good:
assert isinstance(bad_tuple, tuple)
assert all(isinstance(x, str) for x in bad_tuple)
assert all(isinstance(x, (str, tuple)) for x in bad_tuple)
drop_log[sel] = drop_log[sel] + bad_tuple
continue
good_idx.append(idx)
Expand Down Expand Up @@ -3715,7 +3721,10 @@ def _is_good(
for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]):
if refl is not None:
for key, refl in refl.items():
criterion = refl[0]
if isinstance(refl, (tuple, list)):
criterion = refl[0]
else:
criterion = refl
idx = channel_type_idx[key]
name = key.upper()
if len(idx) > 0:
Expand All @@ -3734,17 +3743,26 @@ def _is_good(
)[0]

if len(idx_deltas) > 0:
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
" Rejecting %s epoch based on %s : "
"%s" % (t, name, bad_names)
)
has_printed = True
if not full_report:
return False
if isinstance(refl, (tuple, list)):
reasons = list(refl[1])
for idx, reason in enumerate(reasons):
if isinstance(reason, str):
reasons[idx] = (reason,)
if isinstance(reason, list):
reasons[idx] = tuple(reason)
bad_tuple += tuple(reasons)
else:
bad_tuple += tuple(bad_names)
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
" Rejecting %s epoch based on %s : "
"%s" % (t, name, bad_names)
)
has_printed = True
if not full_report:
return False
else:
bad_tuple += tuple(bad_names)

if not full_report:
return True
Expand Down
47 changes: 40 additions & 7 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,11 @@ def test_average_movements():

def _assert_drop_log_types(drop_log):
__tracebackhide__ = True
assert isinstance(drop_log, tuple), "drop_log should be tuple"
assert isinstance(drop_log, (tuple, list)), """drop_log should be tuple
or list"""
assert all(
isinstance(log, tuple) for log in drop_log
), "drop_log[ii] should be tuple"
isinstance(log, (tuple, list)) for log in drop_log
), "drop_log[ii] should be tuple or list"
assert all(
isinstance(s, str) for log in drop_log for s in log
), "drop_log[ii][jj] should be str"
Expand Down Expand Up @@ -549,7 +550,7 @@ def test_reject():
preload=False,
reject=dict(eeg=np.inf),
)
for val in (None, -1): # protect against older MNE-C types
for val in (-1, (-1, 'Hi')): # protect against older MNE-C types
for kwarg in ("reject", "flat"):
pytest.raises(
ValueError,
Expand All @@ -563,6 +564,21 @@ def test_reject():
preload=False,
**{kwarg: dict(grad=val)},
)
bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
pytest.raises(
TypeError,
Epochs,
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=False,
**{kwarg: dict(grad=val)},
)
pytest.raises(
KeyError,
Epochs,
Expand Down Expand Up @@ -2175,7 +2191,7 @@ def test_callable_reject():
tmax=1,
baseline=None,
reject=dict(
eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False
eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")
),
preload=True,
)
Expand All @@ -2187,15 +2203,15 @@ def test_callable_reject():
tmin=0,
tmax=1,
baseline=None,
reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False),
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), "eeg max")),
preload=True,
)
assert epochs.drop_log[0] != ()

def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return True if max_condition.any() or median_condition.any() else False
(max_condition.any() or median_condition.any(), "eeg max or median")

epochs = mne.Epochs(
edit_raw,
Expand All @@ -2206,6 +2222,7 @@ def reject_criteria(x):
reject=dict(eeg=reject_criteria),
preload=True,
)
print(epochs.drop_log)
assert epochs.drop_log[0] != () and epochs.drop_log[2] != ()


Expand Down Expand Up @@ -3262,6 +3279,22 @@ def test_drop_epochs():
assert_array_equal(events[epochs[3:].selection], events1[[5, 6]])
assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]])

# Test using tuple to drop epochs
raw, events, picks = _get_data()
epochs_tuple = Epochs(
raw, events, event_id,
tmin, tmax, picks=picks, preload=True
)
selection_tuple = epochs_tuple.selection.copy()
epochs_tuple.drop((2, 3, 4), reason=([['list'], 'string', ('tuple',)]))
n_events = len(epochs.events)
assert_equal(
[epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], [
['list'], ["string"], ['tuple']]
)




@pytest.mark.parametrize("preload", (True, False))
def test_drop_epochs_mult(preload):
Expand Down
14 changes: 11 additions & 3 deletions mne/utils/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,22 @@ def _getitem(
drop_log = list(inst.drop_log)
if reason is not None:
# Used for multiple reasons
if isinstance(reason, (list, tuple)):
if isinstance(reason, tuple):
reason = list(reason)

if isinstance(reason, list):
for i, idx in enumerate(
np.setdiff1d(inst.selection, key_selection)
):
drop_log[idx] = reason[i]
r = reason[i]
if isinstance(r, str):
r = (r,)
if isinstance(r, list):
r = tuple(r)
drop_log[idx] = r
else:
for idx in np.setdiff1d(inst.selection, key_selection):
drop_log[idx] = reason
drop_log[idx] = (reason,)
inst.drop_log = tuple(drop_log)
inst.selection = key_selection
del drop_log
Expand Down

0 comments on commit 8401c92

Please sign in to comment.