Skip to content

Commit

Permalink
BUG: Fix bug with regress_artifact picking
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jan 25, 2024
1 parent 03d78f4 commit 4a5b2db
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 12 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :func:`mne.preprocessing.regress_artifact` projection check was not specific to the channels being processed, by `Eric Larson`_.
12 changes: 11 additions & 1 deletion mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,8 @@ def pick_info(info, sel=(), copy=True, verbose=None):
return info
elif len(sel) == 0:
raise ValueError("No channels match the selection.")
n_unique = len(np.unique(np.arange(len(info["ch_names"]))[sel]))
ch_set = set(info["ch_names"][k] for k in sel)
n_unique = len(ch_set)
if n_unique != len(sel):
raise ValueError(
"Found %d / %d unique names, sel is not unique" % (n_unique, len(sel))
Expand Down Expand Up @@ -687,6 +688,15 @@ def pick_info(info, sel=(), copy=True, verbose=None):
if info.get("custom_ref_applied", False) and not _electrode_types(info):
with info._unlock():
info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_OFF
# remove unused projectors
if info.get("projs", []):
projs = list()
for p in info["projs"]:
if any(ch_name in ch_set for ch_name in p["data"]["col_names"]):
projs.append(p)
if len(projs) != len(info["projs"]):
with info._unlock():
info["projs"] = projs
info._check_consistency()

return info
Expand Down
6 changes: 6 additions & 0 deletions mne/_fiff/tests/test_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,17 @@ def test_clean_info_bads():
# simulate the bad channels
raw.info["bads"] = eeg_bad_ch + meg_bad_ch

assert len(raw.info["projs"]) == 3
raw.set_eeg_reference(projection=True)
assert len(raw.info["projs"]) == 4

# simulate the call to pick_info excluding the bad eeg channels
info_eeg = pick_info(raw.info, picks_eeg)
assert len(info_eeg["projs"]) == 1

# simulate the call to pick_info excluding the bad meg channels
info_meg = pick_info(raw.info, picks_meg)
assert len(info_meg["projs"]) == 3

assert info_eeg["bads"] == eeg_bad_ch
assert info_meg["bads"] == meg_bad_ch
Expand Down
24 changes: 13 additions & 11 deletions mne/preprocessing/_regress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .._fiff.pick import _picks_to_idx
from .._fiff.pick import _picks_to_idx, pick_info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..epochs import BaseEpochs
from ..evoked import Evoked
Expand Down Expand Up @@ -178,9 +178,7 @@ def fit(self, inst):
reference (see :func:`mne.set_eeg_reference`) before performing EOG
regression.
"""
self._check_inst(inst)
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
picks, picks_artifact = self._check_inst(inst)

# Calculate regression coefficients. Add a row of ones to also fit the
# intercept.
Expand Down Expand Up @@ -232,9 +230,7 @@ def apply(self, inst, copy=True):
"""
if copy:
inst = inst.copy()
self._check_inst(inst)
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
picks, picks_artifact = self._check_inst(inst)

# Check that the channels are compatible with the regression weights.
ref_picks = _picks_to_idx(
Expand Down Expand Up @@ -324,19 +320,25 @@ def _check_inst(self, inst):
_validate_type(
inst, (BaseRaw, BaseEpochs, Evoked), "inst", "Raw, Epochs, Evoked"
)
if _needs_eeg_average_ref_proj(inst.info):
picks = _picks_to_idx(inst.info, self.picks, none="data", exclude=self.exclude)
picks_artifact = _picks_to_idx(inst.info, self.picks_artifact)
all_picks = np.unique(np.concatenate([picks, picks_artifact]))
use_info = pick_info(inst.info, all_picks)
del all_picks
if _needs_eeg_average_ref_proj(use_info):
raise RuntimeError(
"No reference for the EEG channels has been "
"set. Use inst.set_eeg_reference() to do so."
"No average reference for the EEG channels has been "
"set. Use inst.set_eeg_reference(projection=True) to do so."
)
if self.proj and not inst.proj:
inst.apply_proj()
if not inst.proj and len(inst.info.get("projs", [])) > 0:
if not inst.proj and len(use_info.get("projs", [])) > 0:
raise RuntimeError(
"Projections need to be applied before "
"regression can be performed. Use the "
".apply_proj() method to do so."
)
return picks, picks_artifact

def __repr__(self):
"""Produce a string representation of this object."""
Expand Down
13 changes: 13 additions & 0 deletions mne/preprocessing/tests/test_regress.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def test_regress_artifact():
epochs, betas = regress_artifact(epochs, picks="eog", picks_artifact="eog")
assert np.ptp(epochs.get_data("eog")) < 1e-15 # constant value
assert_allclose(betas, 1)
# proj should only be required of channels being processed
raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
raw.del_proj()
raw.set_eeg_reference(projection=True)
model = EOGRegression(proj=False, picks="meg", picks_artifact="eog")
model.fit(raw)
model.apply(raw)
model = EOGRegression(proj=False, picks="eeg", picks_artifact="eog")
with pytest.raises(RuntimeError, match="Projections need to be applied"):
model.fit(raw)
raw.del_proj()
with pytest.raises(RuntimeError, match="No average reference for the EEG"):
model.fit(raw)


@testing.requires_testing_data
Expand Down

0 comments on commit 4a5b2db

Please sign in to comment.