Skip to content

Commit

Permalink
FIX: Better
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jul 31, 2023
1 parent c100110 commit 3cc4407
Show file tree
Hide file tree
Showing 25 changed files with 81 additions and 122 deletions.
7 changes: 2 additions & 5 deletions mne/channels/tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
Epochs,
)
from mne.datasets import testing
from mne.utils import requires_pandas, requires_version
from mne.parallel import parallel_func

io_dir = Path(__file__).parent.parent.parent / "io"
Expand Down Expand Up @@ -410,10 +409,10 @@ def test_get_set_sensor_positions():
assert_array_equal(raw1.info["chs"][13]["loc"], raw2.info["chs"][13]["loc"])


@requires_version("pymatreader")
@testing.requires_testing_data
def test_1020_selection():
"""Test making a 10/20 selection dict."""
pytest.importorskip("pymatreader")
raw_fname = testing_path / "EEGLAB" / "test_raw.set"
loc_fname = testing_path / "EEGLAB" / "test_chans.locs"
raw = read_raw_eeglab(raw_fname, preload=True)
Expand Down Expand Up @@ -676,11 +675,9 @@ def test_combine_channels():
assert len(record) == 3


@requires_pandas
def test_combine_channels_metadata():
"""Test if metadata is correctly retained in combined object."""
import pandas as pd

pd = pytest.importorskip("pandas")
raw = read_raw_fif(raw_fname, preload=True)
epochs = Epochs(raw, read_events(eve_fname), preload=True)

Expand Down
10 changes: 3 additions & 7 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import pytest

from mne import create_info, EpochsArray
from mne.utils import requires_sklearn
from mne.decoding.base import (
_get_inverse_funcs,
LinearModel,
Expand All @@ -26,6 +25,9 @@
from mne.decoding import Scaler, TransformerMixin, Vectorizer, GeneralizingEstimator


pytest.importorskip("sklearn")


def _make_data(n_samples=1000, n_features=5, n_targets=3):
"""Generate some testing data.
Expand Down Expand Up @@ -65,7 +67,6 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
return X, Y, A


@requires_sklearn
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
from sklearn.base import (
Expand Down Expand Up @@ -204,7 +205,6 @@ def transform(self, X):
inverse_transform = transform


@requires_sklearn
@pytest.mark.parametrize("inverse", (True, False))
@pytest.mark.parametrize(
"Scale, kwargs",
Expand Down Expand Up @@ -241,7 +241,6 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs):
assert_array_equal(filters_t, filters[:, t])


@requires_sklearn
@pytest.mark.parametrize("n_features", [1, 5])
@pytest.mark.parametrize("n_targets", [1, 3])
def test_get_coef_multiclass(n_features, n_targets):
Expand Down Expand Up @@ -288,7 +287,6 @@ def test_get_coef_multiclass(n_features, n_targets):
lm.fit(X, Y, sample_weight=np.ones(len(Y)))


@requires_sklearn
@pytest.mark.parametrize(
"n_classes, n_channels, n_times",
[
Expand Down Expand Up @@ -336,7 +334,6 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
assert_allclose(patterns[:, 1:], 0.0, atol=1e-7) # no other channels useful


@requires_sklearn
def test_linearmodel():
"""Test LinearModel class for computing filters and patterns."""
# check categorical target fit in standard linear model
Expand Down Expand Up @@ -394,7 +391,6 @@ def test_linearmodel():
clf.fit(X, wrong_y)


@requires_sklearn
def test_cross_val_multiscore():
"""Test cross_val_multiscore for computing scores on decoding over time."""
from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score
Expand Down
5 changes: 2 additions & 3 deletions mne/decoding/tests/test_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from mne import io, Epochs, read_events, pick_types
from mne.decoding.csp import CSP, _ajd_pham, SPoC
from mne.utils import requires_sklearn

data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
raw_fname = data_dir / "test_raw.fif"
Expand Down Expand Up @@ -245,9 +244,9 @@ def test_csp():
assert np.abs(corr) > 0.95


@requires_sklearn
def test_regularized_csp():
"""Test Common Spatial Patterns algorithm using regularized covariance."""
pytest.importorskip("sklearn")
raw = io.read_raw_fif(raw_fname)
events = read_events(event_name)
picks = pick_types(
Expand Down Expand Up @@ -281,9 +280,9 @@ def test_regularized_csp():
assert sources.shape[1] == n_components


@requires_sklearn
def test_csp_pipeline():
"""Test if CSP works in a pipeline."""
pytest.importorskip("sklearn")
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline

Expand Down
4 changes: 2 additions & 2 deletions mne/decoding/tests/test_ems.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from mne import io, Epochs, read_events, pick_types
from mne.utils import requires_sklearn
from mne.decoding import compute_ems, EMS

data_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
Expand All @@ -18,8 +17,9 @@
tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, vis_l=3)

pytest.importorskip("sklearn")


@requires_sklearn
def test_ems():
"""Test event-matched spatial filters."""
from sklearn.model_selection import StratifiedKFold
Expand Down
17 changes: 8 additions & 9 deletions mne/decoding/tests/test_receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from numpy.fft import rfft, irfft
from numpy.testing import assert_array_equal, assert_allclose, assert_equal

from mne.utils import requires_sklearn
from mne.decoding import ReceptiveField, TimeDelayingRidge
from mne.decoding.receptive_field import (
_delay_time_series,
Expand Down Expand Up @@ -78,10 +77,10 @@ def test_compute_reg_neighbors():
)


@requires_sklearn
def test_rank_deficiency():
"""Test signals that are rank deficient."""
# See GH#4253
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

N = 256
Expand Down Expand Up @@ -174,9 +173,9 @@ def test_time_delay():

@pytest.mark.slowtest # slow on Azure
@pytest.mark.parametrize("n_jobs", n_jobs_test)
@requires_sklearn
def test_receptive_field_basic(n_jobs):
"""Test model prep and fitting."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

# Make sure estimator pulling works
Expand Down Expand Up @@ -372,9 +371,9 @@ def test_time_delaying_fast_calc(n_jobs):


@pytest.mark.parametrize("n_jobs", n_jobs_test)
@requires_sklearn
def test_receptive_field_1d(n_jobs):
"""Test that the fast solving works like Ridge."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

rng = np.random.RandomState(0)
Expand Down Expand Up @@ -433,9 +432,9 @@ def test_receptive_field_1d(n_jobs):


@pytest.mark.parametrize("n_jobs", n_jobs_test)
@requires_sklearn
def test_receptive_field_nd(n_jobs):
"""Test multidimensional support."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

# multidimensional
Expand Down Expand Up @@ -552,9 +551,9 @@ def _make_data(n_feats, n_targets, n_samples, tmin, tmax):
return X, y


@requires_sklearn
def test_inverse_coef():
"""Test inverse coefficients computation."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

tmin, tmax = 0.0, 10.0
Expand Down Expand Up @@ -583,9 +582,9 @@ def test_inverse_coef():
assert_allclose(np.dot(c0, c1.T), np.eye(c0.shape[0]), atol=0.2)


@requires_sklearn
def test_linalg_warning():
"""Test that warnings are issued when no regularization is applied."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge

n_feats, n_targets, n_samples = 5, 60, 50
Expand All @@ -598,9 +597,9 @@ def test_linalg_warning():
rf.fit(y, X)


@requires_sklearn
def test_tdr_sklearn_compliance():
"""Test sklearn estimator compliance."""
pytest.importorskip("sklearn")
from sklearn.utils.estimator_checks import check_estimator

tdr = TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1)
Expand All @@ -619,9 +618,9 @@ def test_tdr_sklearn_compliance():
check(est)


@requires_sklearn
def test_rf_sklearn_compliance():
"""Test sklearn RF compliance."""
pytest.importorskip("sklearn")
from sklearn.linear_model import Ridge
from sklearn.utils.estimator_checks import check_estimator

Expand Down
9 changes: 3 additions & 6 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from numpy.testing import assert_array_equal, assert_equal
import pytest

from mne.utils import requires_sklearn, _record_warnings, use_log_level
from mne.utils import _record_warnings, use_log_level
from mne.decoding.search_light import SlidingEstimator, GeneralizingEstimator
from mne.decoding.transformer import Vectorizer

pytest.importorskip("sklearn")


def make_data():
"""Make data."""
Expand All @@ -25,7 +27,6 @@ def make_data():
return X, y


@requires_sklearn
def test_search_light():
"""Test SlidingEstimator."""
from sklearn.linear_model import Ridge, LogisticRegression
Expand Down Expand Up @@ -167,7 +168,6 @@ def transform(self, X):
assert isinstance(pipe.estimators_[0], BaggingClassifier)


@requires_sklearn
def test_generalization_light():
"""Test GeneralizingEstimator."""
from sklearn.pipeline import make_pipeline
Expand Down Expand Up @@ -254,7 +254,6 @@ def test_generalization_light():
assert_array_equal(y_preds[0], y_preds[1])


@requires_sklearn
@pytest.mark.parametrize(
"n_jobs, verbose", [(1, False), (2, False), (1, True), (2, "info")]
)
Expand All @@ -280,7 +279,6 @@ def test_verbose_arg(capsys, n_jobs, verbose):
assert any(len(channel) > 0 for channel in (stdout, stderr))


@requires_sklearn
def test_cross_val_predict():
"""Test cross_val_predict with predict_proba."""
from sklearn.linear_model import LinearRegression
Expand Down Expand Up @@ -317,7 +315,6 @@ def predict_proba(self, X):


@pytest.mark.slowtest
@requires_sklearn
def test_sklearn_compliance():
"""Test LinearModel compliance with sklearn."""
from sklearn.utils.estimator_checks import check_estimator
Expand Down
3 changes: 1 addition & 2 deletions mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from mne import io
from mne.time_frequency import psd_array_welch
from mne.decoding.ssd import SSD
from mne.utils import requires_sklearn
from mne.filter import filter_data
from mne import create_info
from mne.decoding import CSP
Expand Down Expand Up @@ -296,9 +295,9 @@ def test_ssd_epoched_data():
)


@requires_sklearn
def test_ssd_pipeline():
"""Test if SSD works in a pipeline."""
pytest.importorskip("sklearn")
from sklearn.pipeline import Pipeline

sf = 250
Expand Down
3 changes: 1 addition & 2 deletions mne/decoding/tests/test_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
from numpy.testing import assert_array_equal
import pytest

from mne.utils import requires_sklearn
from mne.decoding.time_frequency import TimeFrequency


@requires_sklearn
def test_timefrequency():
"""Test TimeFrequency."""
pytest.importorskip("sklearn")
from sklearn.base import clone

# Init
Expand Down
4 changes: 2 additions & 2 deletions mne/decoding/tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
TemporalFilter,
)
from mne.defaults import DEFAULTS
from mne.utils import requires_sklearn, check_version, use_log_level
from mne.utils import check_version, use_log_level

tmin, tmax = -0.2, 0.5
event_id = dict(aud_l=1, vis_l=3)
Expand Down Expand Up @@ -217,9 +217,9 @@ def test_vectorizer():
pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12))


@requires_sklearn
def test_unsupervised_spatial_filter():
"""Test unsupervised spatial filter."""
pytest.importorskip("sklearn")
from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge

Expand Down
4 changes: 2 additions & 2 deletions mne/io/tests/test_what.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
from mne.datasets import testing
from mne.io import RawArray
from mne.preprocessing import ICA
from mne.utils import requires_sklearn, _record_warnings
from mne.utils import _record_warnings

data_path = testing.data_path(download=False)


@pytest.mark.slowtest
@requires_sklearn
@testing.requires_testing_data
def test_what(tmp_path, verbose_debug):
"""Test mne.what."""
pytest.importorskip("sklearn")
# ICA
ica = ICA(max_iter=1)
raw = RawArray(np.random.RandomState(0).randn(3, 10), create_info(3, 1000.0, "eeg"))
Expand Down
Loading

0 comments on commit 3cc4407

Please sign in to comment.