diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 325be7350a6..3cb7b62b06c 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -870,9 +870,12 @@ def interpolate_bads( .. versionadded:: 0.9.0 """ from .interpolation import ( + _interpolate_bads_nan, _interpolate_bads_eeg, _interpolate_bads_meeg, _interpolate_bads_nirs, + _interpolate_bads_ecog, + _interpolate_bads_seeg, ) _check_preload(self, "interpolation") @@ -894,9 +897,11 @@ def interpolate_bads( "eeg": ("spline", "MNE", "nan"), "meg": ("MNE", "nan"), "fnirs": ("nearest", "nan"), + "ecog": ("spline", "nan"), + "seeg": ("spline", "nan"), } for key in method: - _check_option("method[key]", key, ("meg", "eeg", "fnirs")) + _check_option("method[key]", key, tuple(valids.keys())) _check_option(f"method['{key}']", method[key], valids[key]) logger.info("Setting channel interpolation method to %s.", method) idx = _picks_to_idx(self.info, list(method), exclude=(), allow_empty=True) @@ -905,24 +910,27 @@ def interpolate_bads( return self logger.info("Interpolating bad channels.") origin = _check_origin(origin, self.info) + for ch_type, interp in method.items(): + if interp == "nan": + _interpolate_bads_nan(self, ch_type, exclude=exclude) if method.get("eeg", "") == "spline": _interpolate_bads_eeg(self, origin=origin, exclude=exclude) - eeg_mne = False - elif "eeg" not in method: - eeg_mne = False - else: - eeg_mne = True - if "meg" in method or eeg_mne: + if method.get("meg", "") == "MNE" or method.get("eeg", "") == "MNE": _interpolate_bads_meeg( self, mode=mode, + meg=method.get("meg", "") == "MNE", + eeg=method.get("eeg", "") == "MNE", origin=origin, - eeg=eeg_mne, exclude=exclude, method=method, ) - if "fnirs" in method: - _interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"]) + if method.get("fnirs", "") == "nearest": + _interpolate_bads_nirs(self, exclude=exclude) + if method.get("ecog", "") == "spline": + _interpolate_bads_ecog(self, origin=origin, exclude=exclude) + if method.get("seeg", "") == "spline": + _interpolate_bads_seeg(self, exclude=exclude) if reset_bads is True: if "nan" in method.values(): diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index 807639b8bcf..c5607dfc93f 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -8,6 +8,7 @@ from numpy.polynomial.legendre import legval from scipy.linalg import pinv from scipy.spatial.distance import pdist, squareform +from scipy.interpolate import RectBivariateSpline from .._fiff.meas_info import _simplify_info from .._fiff.pick import pick_channels, pick_info, pick_types @@ -132,13 +133,13 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx): @verbose -def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): +def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None): if exclude is None: exclude = list() bads_idx = np.zeros(len(inst.ch_names), dtype=bool) goods_idx = np.zeros(len(inst.ch_names), dtype=bool) - picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude) + picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude) inst.info._check_consistency() bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks] @@ -172,6 +173,11 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None): _do_interp_dots(inst, interpolation, goods_idx, bads_idx) +@verbose +def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None): + _interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose) + + def _interpolate_bads_meg( inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False ): @@ -180,6 +186,26 @@ def _interpolate_bads_meg( ) +@verbose +def _interpolate_bads_nan( + inst, + ch_type, + ref_meg=False, + exclude=(), + *, + verbose=None, +): + info = _simplify_info(inst.info) + picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True}) + use_ch_names = [inst.info["ch_names"][p] for p in picks_type] + bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names] + if len(bads_type) == 0 or len(picks_type) == 0: + return + # select the bad channels to be interpolated + picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) + inst._data[..., picks_bad, :] = np.nan + + @verbose def _interpolate_bads_meeg( inst, @@ -213,10 +239,6 @@ def _interpolate_bads_meeg( # select the bad channels to be interpolated picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) - if method[ch_type] == "nan": - inst._data[picks_bad] = np.nan - continue - # do MNE based interpolation if ch_type == "eeg": picks_to = picks_type @@ -232,7 +254,7 @@ def _interpolate_bads_meeg( @verbose -def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): +def _interpolate_bads_nirs(inst, exclude=(), verbose=None): from mne.preprocessing.nirs import _validate_nirs_info if len(pick_types(inst.info, fnirs=True, exclude=())) == 0: @@ -251,25 +273,91 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): chs = [inst.info["chs"][i] for i in picks_nirs] locs3d = np.array([ch["loc"][:3] for ch in chs]) - _check_option("fnirs_method", method, ["nearest", "nan"]) - - if method == "nearest": - dist = pdist(locs3d) - dist = squareform(dist) - - for bad in picks_bad: - dists_to_bad = dist[bad] - # Ignore distances to self - dists_to_bad[dists_to_bad == 0] = np.inf - # Ignore distances to other bad channels - dists_to_bad[bads_mask] = np.inf - # Find closest remaining channels for same frequency - closest_idx = np.argmin(dists_to_bad) + (bad % 2) - inst._data[bad] = inst._data[closest_idx] - else: - assert method == "nan" - inst._data[picks_bad] = np.nan + dist = pdist(locs3d) + dist = squareform(dist) + + for bad in picks_bad: + dists_to_bad = dist[bad] + # Ignore distances to self + dists_to_bad[dists_to_bad == 0] = np.inf + # Ignore distances to other bad channels + dists_to_bad[bads_mask] = np.inf + # Find closest remaining channels for same frequency + closest_idx = np.argmin(dists_to_bad) + (bad % 2) + inst._data[bad] = inst._data[closest_idx] + # TODO: this seems like a bug because it does not respect reset_bads inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] return inst + + +@verbose +def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None): + if exclude is None: + exclude = list() + bads_idx = np.zeros(len(inst.ch_names), dtype=bool) + goods_idx = np.zeros(len(inst.ch_names), dtype=bool) + + picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude) + inst.info._check_consistency() + bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks] + + if len(picks) < 3 or bads_idx.sum() == 0: + return + + goods_idx[picks] = True + goods_idx[bads_idx] = False + + pos = inst._get_channel_positions(picks) + + # Make sure only sEEG are used + bads_idx_pos = bads_idx[picks] + + # for each bad contact: + # 1) find nearest neighbor to define the electrode shaft line + # 2) find all contacts on the same line + # 3) interpolate the bad contacts + + dist = squareform(pdist(pos)) + np.fill_diagonal(dist, np.inf) + + picks_bad = list(np.where(bads_idx_pos)[0]) + while picks_bad: + bad = picks_bad[0] + n1 = pos[bad] + n2 = pos[np.argmin(dist[bad])] # 1 + # https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html + shaft_dists = np.array( + [ + 0 + if all(n0 == n1) or all(n0 == n2) + else np.linalg.norm(np.cross((n0 - n1), (n0 - n2))) + / np.linalg.norm(n2 - n1) + for n0 in pos + ] + ) + shaft = np.where(shaft_dists < tol)[0] # 2 + if shaft.size < 3: + raise RuntimeError( + f"Only {shaft.size} contact positions in a line " + f" found for {inst.ch_names[bad]}, 3 required for " + "interpolation, fix the positions or exclude this channel" + ) + bads_shaft = np.array([idx for idx in picks_bad if idx in shaft]) + goods_shaft = shaft[~np.in1d(shaft, bads_shaft)] + bads_shaft_idx = np.where(np.in1d(shaft, bads_shaft))[0] + goods_shaft_idx = np.where(~np.in1d(shaft, bads_shaft))[0] + for bad in bads_shaft: + picks_bad.remove(bad) # interpolating, remove + ts = np.array( + [ + np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2 + for n0 in pos[shaft] + ] + ) + y = np.arange(inst._data.shape[-1]) + inst._data[bads_shaft] = RectBivariateSpline( + x=ts[goods_shaft_idx], y=y, + z=inst._data[goods_shaft] + )(x=ts[bads_shaft_idx], y=y) # 3 diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 999e0c16402..d53174cc3fb 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -10,6 +10,7 @@ from mne import Epochs, pick_channels, pick_types, read_events from mne._fiff.constants import FIFF from mne._fiff.proj import _has_eeg_average_ref_proj +from mne.channels import make_dig_montage from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx @@ -329,6 +330,49 @@ def test_interpolation_nirs(): assert raw_haemo.info["bads"] == [] +@testing.requires_testing_data +def test_interpolation_ieeg(): + """Test interpolation for sEEG and ECoG.""" + raw, epochs_eeg = _load_data("eeg") + bads = ["EEG 012"] + bads_idx = np.where(np.in1d(epochs_eeg.ch_names, bads))[0] + + epochs_ecog = epochs_eeg.copy().set_channel_types( + {ch: "ecog" for ch in epochs_eeg.ch_names} + ) + epochs_ecog.info["bads"] = bads + + # check that interpolation changes the data in raw + raw_ecog = RawArray(data=epochs_ecog._data[0], info=epochs_ecog.info) + raw_before = raw_ecog._data[bads_idx] + raw_after = raw_ecog.interpolate_bads(method=dict(ecog="spline"))._data[bads_idx] + assert not np.all(raw_before == raw_after) + + epochs_seeg = epochs_eeg.copy().set_channel_types( + {ch: "seeg" for ch in epochs_eeg.ch_names} + ) + + epochs_seeg.info["bads"] = bads + + # check that interpolation changes the data in raw + raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info) + raw_before = raw_seeg._data[bads_idx] + with pytest.raises(RuntimeError, match="Only 2 contact positions"): + raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[ + bads_idx + ] + montage = raw_seeg.get_montage() + pos = montage.get_positions() + ch_pos = pos.pop('ch_pos') + n0 = ch_pos[epochs_seeg.ch_names[0]] + n1 = ch_pos[epochs_seeg.ch_names[1]] + for i, ch in enumerate(epochs_seeg.ch_names[2:]): + ch_pos[ch] = n0 + (n1 - n0) * (i + 2) + raw_seeg.set_montage(make_dig_montage(ch_pos, **pos)) + raw_after = raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[bads_idx] + assert not np.all(raw_before == raw_after) + + def test_nan_interpolation(raw): """Test 'nan' method for interpolating bads.""" ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG) diff --git a/mne/defaults.py b/mne/defaults.py index 8732280998f..b9e6702edec 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -278,7 +278,9 @@ combine_xyz="fro", allow_fixed_depth=True, ), - interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"), + interpolation_method=dict( + eeg="spline", meg="MNE", fnirs="nearest", ecog="spline", seeg="spline" + ), volume_options=dict( alpha=None, resolution=1.0,