Skip to content

Commit

Permalink
[ENH] Add support for ieeg interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Jan 4, 2024
1 parent 596122d commit 3264e36
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 36 deletions.
28 changes: 18 additions & 10 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,9 +870,12 @@ def interpolate_bads(
.. versionadded:: 0.9.0
"""
from .interpolation import (
_interpolate_bads_nan,
_interpolate_bads_eeg,
_interpolate_bads_meeg,
_interpolate_bads_nirs,
_interpolate_bads_ecog,
_interpolate_bads_seeg,
)

_check_preload(self, "interpolation")
Expand All @@ -894,9 +897,11 @@ def interpolate_bads(
"eeg": ("spline", "MNE", "nan"),
"meg": ("MNE", "nan"),
"fnirs": ("nearest", "nan"),
"ecog": ("spline", "nan"),
"seeg": ("spline", "nan"),
}
for key in method:
_check_option("method[key]", key, ("meg", "eeg", "fnirs"))
_check_option("method[key]", key, tuple(valids.keys()))
_check_option(f"method['{key}']", method[key], valids[key])
logger.info("Setting channel interpolation method to %s.", method)
idx = _picks_to_idx(self.info, list(method), exclude=(), allow_empty=True)
Expand All @@ -905,24 +910,27 @@ def interpolate_bads(
return self
logger.info("Interpolating bad channels.")
origin = _check_origin(origin, self.info)
for ch_type, interp in method.items():
if interp == "nan":
_interpolate_bads_nan(self, ch_type, exclude=exclude)
if method.get("eeg", "") == "spline":
_interpolate_bads_eeg(self, origin=origin, exclude=exclude)
eeg_mne = False
elif "eeg" not in method:
eeg_mne = False
else:
eeg_mne = True
if "meg" in method or eeg_mne:
if method.get("meg", "") == "MNE" or method.get("eeg", "") == "MNE":
_interpolate_bads_meeg(
self,
mode=mode,
meg=method.get("meg", "") == "MNE",
eeg=method.get("eeg", "") == "MNE",
origin=origin,
eeg=eeg_mne,
exclude=exclude,
method=method,
)
if "fnirs" in method:
_interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"])
if method.get("fnirs", "") == "nearest":
_interpolate_bads_nirs(self, exclude=exclude)
if method.get("ecog", "") == "spline":
_interpolate_bads_ecog(self, origin=origin, exclude=exclude)
if method.get("seeg", "") == "spline":
_interpolate_bads_seeg(self, exclude=exclude)

if reset_bads is True:
if "nan" in method.values():
Expand Down
138 changes: 113 additions & 25 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy.polynomial.legendre import legval
from scipy.linalg import pinv
from scipy.spatial.distance import pdist, squareform
from scipy.interpolate import RectBivariateSpline

from .._fiff.meas_info import _simplify_info
from .._fiff.pick import pick_channels, pick_info, pick_types
Expand Down Expand Up @@ -132,13 +133,13 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):


@verbose
def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
if exclude is None:
exclude = list()
bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude)
picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
inst.info._check_consistency()
bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]

Expand Down Expand Up @@ -172,6 +173,11 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)


@verbose
def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
_interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose)


def _interpolate_bads_meg(
inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
):
Expand All @@ -180,6 +186,26 @@ def _interpolate_bads_meg(
)


@verbose
def _interpolate_bads_nan(
inst,
ch_type,
ref_meg=False,
exclude=(),
*,
verbose=None,
):
info = _simplify_info(inst.info)
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True})
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
if len(bads_type) == 0 or len(picks_type) == 0:
return
# select the bad channels to be interpolated
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
inst._data[..., picks_bad, :] = np.nan


@verbose
def _interpolate_bads_meeg(
inst,
Expand Down Expand Up @@ -213,10 +239,6 @@ def _interpolate_bads_meeg(
# select the bad channels to be interpolated
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])

if method[ch_type] == "nan":
inst._data[picks_bad] = np.nan
continue

# do MNE based interpolation
if ch_type == "eeg":
picks_to = picks_type
Expand All @@ -232,7 +254,7 @@ def _interpolate_bads_meeg(


@verbose
def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
from mne.preprocessing.nirs import _validate_nirs_info

if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
Expand All @@ -251,25 +273,91 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
chs = [inst.info["chs"][i] for i in picks_nirs]
locs3d = np.array([ch["loc"][:3] for ch in chs])

_check_option("fnirs_method", method, ["nearest", "nan"])

if method == "nearest":
dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]
else:
assert method == "nan"
inst._data[picks_bad] = np.nan
dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

# TODO: this seems like a bug because it does not respect reset_bads
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]

return inst


@verbose
def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None):
if exclude is None:
exclude = list()
bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
inst.info._check_consistency()
bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]

if len(picks) < 3 or bads_idx.sum() == 0:
return

goods_idx[picks] = True
goods_idx[bads_idx] = False

pos = inst._get_channel_positions(picks)

# Make sure only sEEG are used
bads_idx_pos = bads_idx[picks]

# for each bad contact:
# 1) find nearest neighbor to define the electrode shaft line
# 2) find all contacts on the same line
# 3) interpolate the bad contacts

dist = squareform(pdist(pos))
np.fill_diagonal(dist, np.inf)

picks_bad = list(np.where(bads_idx_pos)[0])
while picks_bad:
bad = picks_bad[0]
n1 = pos[bad]
n2 = pos[np.argmin(dist[bad])] # 1
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
shaft_dists = np.array(
[
0
if all(n0 == n1) or all(n0 == n2)
else np.linalg.norm(np.cross((n0 - n1), (n0 - n2)))
/ np.linalg.norm(n2 - n1)
for n0 in pos
]
)
shaft = np.where(shaft_dists < tol)[0] # 2
if shaft.size < 3:
raise RuntimeError(
f"Only {shaft.size} contact positions in a line "
f" found for {inst.ch_names[bad]}, 3 required for "
"interpolation, fix the positions or exclude this channel"
)
bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
goods_shaft = shaft[~np.in1d(shaft, bads_shaft)]
bads_shaft_idx = np.where(np.in1d(shaft, bads_shaft))[0]
goods_shaft_idx = np.where(~np.in1d(shaft, bads_shaft))[0]
for bad in bads_shaft:
picks_bad.remove(bad) # interpolating, remove
ts = np.array(
[
np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2
for n0 in pos[shaft]
]
)
y = np.arange(inst._data.shape[-1])
inst._data[bads_shaft] = RectBivariateSpline(
x=ts[goods_shaft_idx], y=y,
z=inst._data[goods_shaft]
)(x=ts[bads_shaft_idx], y=y) # 3
44 changes: 44 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mne import Epochs, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
Expand Down Expand Up @@ -329,6 +330,49 @@ def test_interpolation_nirs():
assert raw_haemo.info["bads"] == []


@testing.requires_testing_data
def test_interpolation_ieeg():
"""Test interpolation for sEEG and ECoG."""
raw, epochs_eeg = _load_data("eeg")
bads = ["EEG 012"]
bads_idx = np.where(np.in1d(epochs_eeg.ch_names, bads))[0]

epochs_ecog = epochs_eeg.copy().set_channel_types(
{ch: "ecog" for ch in epochs_eeg.ch_names}
)
epochs_ecog.info["bads"] = bads

# check that interpolation changes the data in raw
raw_ecog = RawArray(data=epochs_ecog._data[0], info=epochs_ecog.info)
raw_before = raw_ecog._data[bads_idx]
raw_after = raw_ecog.interpolate_bads(method=dict(ecog="spline"))._data[bads_idx]
assert not np.all(raw_before == raw_after)

epochs_seeg = epochs_eeg.copy().set_channel_types(
{ch: "seeg" for ch in epochs_eeg.ch_names}
)

epochs_seeg.info["bads"] = bads

# check that interpolation changes the data in raw
raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info)
raw_before = raw_seeg._data[bads_idx]
with pytest.raises(RuntimeError, match="Only 2 contact positions"):
raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[
bads_idx
]
montage = raw_seeg.get_montage()
pos = montage.get_positions()
ch_pos = pos.pop('ch_pos')
n0 = ch_pos[epochs_seeg.ch_names[0]]
n1 = ch_pos[epochs_seeg.ch_names[1]]
for i, ch in enumerate(epochs_seeg.ch_names[2:]):
ch_pos[ch] = n0 + (n1 - n0) * (i + 2)
raw_seeg.set_montage(make_dig_montage(ch_pos, **pos))
raw_after = raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[bads_idx]
assert not np.all(raw_before == raw_after)


def test_nan_interpolation(raw):
"""Test 'nan' method for interpolating bads."""
ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG)
Expand Down
4 changes: 3 additions & 1 deletion mne/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@
combine_xyz="fro",
allow_fixed_depth=True,
),
interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"),
interpolation_method=dict(
eeg="spline", meg="MNE", fnirs="nearest", ecog="spline", seeg="spline"
),
volume_options=dict(
alpha=None,
resolution=1.0,
Expand Down

0 comments on commit 3264e36

Please sign in to comment.