diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 3cb7b62b06c..21650a1ef68 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -870,11 +870,11 @@ def interpolate_bads( .. versionadded:: 0.9.0 """ from .interpolation import ( - _interpolate_bads_nan, + _interpolate_bads_ecog, _interpolate_bads_eeg, _interpolate_bads_meeg, + _interpolate_bads_nan, _interpolate_bads_nirs, - _interpolate_bads_ecog, _interpolate_bads_seeg, ) diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index c5607dfc93f..2af029ba5e6 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -6,14 +6,14 @@ import numpy as np from numpy.polynomial.legendre import legval +from scipy.interpolate import RectBivariateSpline from scipy.linalg import pinv from scipy.spatial.distance import pdist, squareform -from scipy.interpolate import RectBivariateSpline from .._fiff.meas_info import _simplify_info from .._fiff.pick import pick_channels, pick_info, pick_types from ..surface import _normalize_vectors -from ..utils import _check_option, _validate_type, logger, verbose, warn +from ..utils import _validate_type, logger, verbose, warn def _calc_h(cosang, stiffness=4, n_legendre_terms=50): @@ -358,6 +358,5 @@ def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None): ) y = np.arange(inst._data.shape[-1]) inst._data[bads_shaft] = RectBivariateSpline( - x=ts[goods_shaft_idx], y=y, - z=inst._data[goods_shaft] + x=ts[goods_shaft_idx], y=y, z=inst._data[goods_shaft] )(x=ts[bads_shaft_idx], y=y) # 3 diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index d53174cc3fb..48e6b9c2409 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -358,12 +358,10 @@ def test_interpolation_ieeg(): raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info) raw_before = raw_seeg._data[bads_idx] with pytest.raises(RuntimeError, match="Only 2 contact positions"): - raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[ - bads_idx - ] + raw_seeg.interpolate_bads(method=dict(seeg="linear"))._data[bads_idx] montage = raw_seeg.get_montage() pos = montage.get_positions() - ch_pos = pos.pop('ch_pos') + ch_pos = pos.pop("ch_pos") n0 = ch_pos[epochs_seeg.ch_names[0]] n1 = ch_pos[epochs_seeg.ch_names[1]] for i, ch in enumerate(epochs_seeg.ch_names[2:]):