From 93b160441869f5bf8dbff3bc2a61b1f2d7715f52 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 31 Jul 2023 14:15:20 -0400 Subject: [PATCH] MAINT: Improve sklearn compliance --- mne/decoding/base.py | 7 +++-- mne/decoding/receptive_field.py | 3 +- mne/decoding/tests/test_base.py | 10 +++++-- mne/decoding/tests/test_receptive_field.py | 32 ++++++++++++++++++++++ mne/decoding/time_delaying_ridge.py | 18 ++++++++++-- mne/fixes.py | 32 ---------------------- 6 files changed, 60 insertions(+), 42 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 247c6f89f2d..2cfd540f8ec 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -10,7 +10,7 @@ import datetime as dt import numbers from ..parallel import parallel_func -from ..fixes import BaseEstimator, is_classifier, _get_check_scoring +from ..fixes import BaseEstimator, _get_check_scoring from ..utils import warn, verbose @@ -224,6 +224,8 @@ def classes_(self): def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" # Detect whether classification or regression + from sklearn.base import is_classifier + if estimator in ["classifier", "regressor"]: est_is_classifier = estimator == "classifier" else: @@ -440,8 +442,7 @@ def cross_val_multiscore( Array of scores of the estimator for each run of the cross validation. """ # This code is copied from sklearn - - from sklearn.base import clone + from sklearn.base import clone, is_classifier from sklearn.utils import indexable from sklearn.model_selection._split import check_cv diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index 6fa38a4f72f..5473e0c10b7 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -9,7 +9,6 @@ from .base import get_coef, BaseEstimator, _check_estimator from .time_delaying_ridge import TimeDelayingRidge -from ..fixes import is_regressor from ..utils import _validate_type, verbose, fill_doc @@ -183,7 +182,7 @@ def fit(self, X, y): "scoring must be one of %s, got" "%s " % (sorted(_SCORERS.keys()), self.scoring) ) - from sklearn.base import clone + from sklearn.base import clone, is_regressor X, y, _, self._y_dim = self._check_dimensions(X, y) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index c7773a217d4..eaf0164006d 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -14,7 +14,6 @@ import pytest from mne import create_info, EpochsArray -from mne.fixes import is_regressor, is_classifier from mne.utils import requires_sklearn from mne.decoding.base import ( _get_inverse_funcs, @@ -69,7 +68,12 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): @requires_sklearn def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" - from sklearn.base import TransformerMixin, BaseEstimator + from sklearn.base import ( + TransformerMixin, + BaseEstimator, + is_classifier, + is_regressor, + ) from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn import svm @@ -336,6 +340,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times): def test_linearmodel(): """Test LinearModel class for computing filters and patterns.""" # check categorical target fit in standard linear model + from sklearn.utils.estimator_checks import check_estimator from sklearn.linear_model import LinearRegression rng = np.random.RandomState(0) @@ -388,6 +393,7 @@ def test_linearmodel(): with pytest.raises(ValueError): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) + check_estimator(clf) @requires_sklearn diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index 9a993b43669..c3014a98d7d 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -596,3 +596,35 @@ def test_linalg_warning(): (RuntimeWarning, UserWarning), match="[Singular|scipy.linalg.solve]" ): rf.fit(y, X) + + +@requires_sklearn +def test_tdr_sklearn_compliance(): + """Test sklearn estimator compliance.""" + from sklearn.utils.estimator_checks import check_estimator + + tdr = TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1) + # We don't actually comply with a bunch of the regressor specs :( + ignores = ( + "check_supervised_y_no_nan", + "check_regressor", + "check_parameters_default_constructible", + "check_estimators_unfitted", + "_invariance", + "check_fit2d_1sample", + ) + for est, check in check_estimator(tdr, generate_only=True): + if any(ignore in str(check) for ignore in ignores): + continue + check(est) + + +@requires_sklearn +def test_rf_sklearn_compliance(): + """Test sklearn RF compliance.""" + from sklearn.linear_model import Ridge + from sklearn.utils.estimator_checks import check_estimator + + rf = ReceptiveField(-0.1, 0.2, 1000.0, estimator=Ridge(), patterns=True) + for est, check in check_estimator(rf, generate_only=True): + check(est) diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index 2299aa5d861..2853cc98b36 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -10,7 +10,7 @@ from ..cuda import _setup_cuda_fft_multiply_repeated from ..filter import next_fast_len from ..fixes import jit -from ..utils import warn, ProgressBar, logger +from ..utils import warn, ProgressBar, logger, _validate_type, _check_option def _compute_corrs( @@ -301,6 +301,9 @@ def __init__( self.edge_correction = edge_correction self.n_jobs = n_jobs + def _more_tags(self): + return {"no_validation": True} + @property def _smin(self): return int(round(self.tmin * self.sfreq)) @@ -324,12 +327,21 @@ def fit(self, X, y): self : instance of TimeDelayingRidge Returns the modified instance. """ + _validate_type(X, "array-like", "X") + _validate_type(y, "array-like", "y") + X = np.asarray(X, dtype=float) + y = np.asarray(y, dtype=float) if X.ndim == 3: assert y.ndim == 3 assert X.shape[:2] == y.shape[:2] else: - assert X.ndim == 2 and y.ndim == 2 - assert X.shape[0] == y.shape[0] + if X.ndim == 1: + X = X[:, np.newaxis] + if y.ndim == 1: + y = y[:, np.newaxis] + assert X.ndim == 2 + assert y.ndim == 2 + _check_option("y.shape[0]", y.shape[0], (X.shape[0],)) # These are split into two functions because it's possible that we # might want to allow people to do them separately (e.g., to test # different regularization parameters). diff --git a/mne/fixes.py b/mne/fixes.py index c05dfaec344..b89e6cd0f84 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -176,38 +176,6 @@ def _read_volume_info(fobj): # adapted from scikit-learn -def is_classifier(estimator): - """Returns True if the given estimator is (probably) a classifier. - - Parameters - ---------- - estimator : object - Estimator object to test. - - Returns - ------- - out : bool - True if estimator is a classifier and False otherwise. - """ - return getattr(estimator, "_estimator_type", None) == "classifier" - - -def is_regressor(estimator): - """Returns True if the given estimator is (probably) a regressor. - - Parameters - ---------- - estimator : object - Estimator object to test. - - Returns - ------- - out : bool - True if estimator is a regressor and False otherwise. - """ - return getattr(estimator, "_estimator_type", None) == "regressor" - - _DEFAULT_TAGS = { "non_deterministic": False, "requires_positive_X": False,