Skip to content

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Dec 22, 2023
1 parent f9401bd commit cad31f4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
71 changes: 38 additions & 33 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,7 @@ def __init__(
if events is not None: # RtEpochs can have events=None
for key, val in self.event_id.items():
if val not in events[:, 2]:
msg = "No matching events found for %s " "(event id %i)" % (
key,
val,
)
msg = f"No matching events found for {key} (event id {val})"
_on_missing(on_missing, msg)

# ensure metadata matches original events size
Expand Down Expand Up @@ -3105,6 +3102,40 @@ def _ensure_list(x):
return metadata, events, event_id


def _events_from_annotations(raw, events, event_id, annotations, on_missing):
"""Generate events and event_ids from annotations."""
events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations and creating fixed-duration epochs "
"around annotation onsets."
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
if isinstance(event_id, str):
event_id = [event_id]
if isinstance(event_id, (list, tuple, set)):
if not set(event_id).issubset(set(event_id_tmp)):
msg = (
"No matching annotations found for event_id(s) "
f"{set(event_id) - set(event_id_tmp)}"
)
_on_missing(on_missing, msg)
# remove extras if on_missing not error
event_id = set(event_id) & set(event_id_tmp)
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(~np.isin(raw.annotations.description, list(event_id)))
return events, event_id, annotations


@fill_doc
class Epochs(BaseEpochs):
"""Epochs extracted from a Raw instance.
Expand Down Expand Up @@ -3262,35 +3293,9 @@ def __init__(

# get events from annotations if no events given
if events is None:
events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations and creating fixed-duration epochs "
"around annotation onsets."
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
if isinstance(event_id, str):
event_id = [event_id]
if isinstance(event_id, (list, tuple, set)):
if set(event_id).issubset(set(event_id_tmp)):
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(
~np.isin(raw.annotations.description, list(event_id))
)
else:
raise RuntimeError(
f"event_id(s) {set(event_id) - set(event_id_tmp)} "
"not found in annotations"
)
events, event_id, annotations = _events_from_annotations(
raw, events, event_id, annotations, on_missing
)

# call BaseEpochs constructor
super(Epochs, self).__init__(
Expand Down
6 changes: 5 additions & 1 deletion mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,8 +1004,12 @@ def test_epochs_from_annotations():
events, raw.info["sfreq"], first_samp=raw.first_samp
)
)
with pytest.raises(RuntimeError, match="not found in annotations"):
# test on_missing
with pytest.raises(ValueError, match="No matching annotations"):
Epochs(raw, event_id="foo")
# test on_missing warn
with pytest.warns(match="No matching annotations"):
Epochs(raw, event_id=["1", "foo"], on_missing="warn")


def test_epochs_hash():
Expand Down

0 comments on commit cad31f4

Please sign in to comment.