Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 9, 2024
1 parent cf1facf commit e98bee2
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 42 deletions.
5 changes: 3 additions & 2 deletions mne/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union

VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object
Expand All @@ -12,5 +13,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '1.6.0.dev139+gfdaeb8620'
__version_tuple__ = version_tuple = (1, 6, 0, 'dev139', 'gfdaeb8620')
__version__ = version = "1.6.0.dev139+gfdaeb8620"
__version_tuple__ = version_tuple = (1, 6, 0, "dev139", "gfdaeb8620")
28 changes: 12 additions & 16 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,7 @@ def _reject_setup(self, reject, flat):
)
bads = set(rej.keys()) - set(idx.keys())
if len(bads) > 0:
raise KeyError(
"Unknown channel types found in %s: %s" % (kind, bads))
raise KeyError("Unknown channel types found in %s: %s" % (kind, bads))

for key in idx.keys():
# don't throw an error if rejection/flat would do nothing
Expand All @@ -815,7 +814,7 @@ def _reject_setup(self, reject, flat):
"%s." % (key.upper(), key.upper())
)

# check for invalid values
# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
Expand All @@ -825,7 +824,8 @@ def _reject_setup(self, reject, flat):
if val is None or val < 0:
raise ValueError(
"""If using numerical %s criteria, the value
must be >= 0 Not '%s'.""" % (kind, val)
must be >= 0 Not '%s'."""
% (kind, val)
)

# now check to see if our rejection and flat are getting more
Expand Down Expand Up @@ -1547,17 +1547,16 @@ def drop(self, indices, reason="USER", verbose=None):

if indices.ndim > 1:
raise ValueError("indices must be a scalar or a 1-d array")
# Check if indices and reasons are of the same length
# Check if indices and reasons are of the same length
# if using collection to drop epochs
if (isinstance(reason, (list, tuple))):
if isinstance(reason, (list, tuple)):
if len(indices) != len(reason):
raise ValueError(
"If using a list or tuple as the reason, "
"indices and reasons must be of the same length, got "
f"{len(indices)} and {len(reason)}"
)


if indices.dtype == bool:
indices = np.where(indices)[0]
try_idx = np.where(indices < 0, indices + len(self.events), indices)
Expand Down Expand Up @@ -3719,19 +3718,16 @@ def _is_good(
_validate_type(result, tuple, result, "tuple")
if len(result) != 2:
raise TypeError(
"Function criterion must return a "
"tuple of length 2"
"Function criterion must return a " "tuple of length 2"
)
cri_truth, reasons = result
_validate_type(cri_truth, (bool, np.bool_),
cri_truth, "bool")
_validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool")
_validate_type(
reasons, (str, list, tuple),
reasons, "str, list, or tuple"
reasons, (str, list, tuple), reasons, "str, list, or tuple"
)
idx_deltas = np.where(
np.logical_and(cri_truth, checkable_idx)
)[0]
idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[
0
]
else:
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
idx_deltas = np.where(
Expand Down
24 changes: 7 additions & 17 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def my_reject_2(epoch_data):
reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs)
return len(bad_idxs), reasons

bad_types = [my_reject_1, my_reject_2, ('Hi' 'Hi'), (1, 1)]
bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1)]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
pytest.raises(
Expand Down Expand Up @@ -2200,9 +2200,7 @@ def test_callable_reject():
tmin=0,
tmax=1,
baseline=None,
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")
),
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")),
preload=True,
)
assert epochs.drop_log[2] != ()
Expand All @@ -2221,10 +2219,7 @@ def test_callable_reject():
def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return (
(max_condition.any() or median_condition.any()),
"eeg max or median"
)
return ((max_condition.any() or median_condition.any()), "eeg max or median")

epochs = mne.Epochs(
edit_raw,
Expand Down Expand Up @@ -3294,21 +3289,16 @@ def test_drop_epochs():

# Test using tuple to drop epochs
raw, events, picks = _get_data()
epochs_tuple = Epochs(
raw, events, event_id,
tmin, tmax, picks=picks, preload=True
)
epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True)
selection_tuple = epochs_tuple.selection.copy()
epochs_tuple.drop((2, 3, 4), reason=([['list'], 'string', ('tuple',)]))
epochs_tuple.drop((2, 3, 4), reason=([["list"], "string", ("tuple",)]))
n_events = len(epochs.events)
assert_equal(
[epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], [
['list'], ["string"], ['tuple']]
[epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]],
[["list"], ["string"], ["tuple"]],
)




@pytest.mark.parametrize("preload", (True, False))
def test_drop_epochs_mult(preload):
"""Test that subselecting epochs or making fewer epochs is similar."""
Expand Down
9 changes: 2 additions & 7 deletions tutorials/preprocessing/20_rejecting_bad_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,7 @@
tmin=0,
tmax=1,
baseline=None,
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")
),
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")),
preload=True,
)
epochs.plot(scalings=dict(eeg=50e-5))
Expand All @@ -413,10 +411,7 @@
def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return (
(max_condition.any() or median_condition.any()),
["max amp", "median amp"]
)
return ((max_condition.any() or median_condition.any()), ["max amp", "median amp"])


epochs = mne.Epochs(
Expand Down

0 comments on commit e98bee2

Please sign in to comment.