Skip to content

Commit

Permalink
MAINT: Improve sklearn compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jul 31, 2023
1 parent a560999 commit 93b1604
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 42 deletions.
7 changes: 4 additions & 3 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime as dt
import numbers
from ..parallel import parallel_func
from ..fixes import BaseEstimator, is_classifier, _get_check_scoring
from ..fixes import BaseEstimator, _get_check_scoring
from ..utils import warn, verbose


Expand Down Expand Up @@ -224,6 +224,8 @@ def classes_(self):
def _set_cv(cv, estimator=None, X=None, y=None):
"""Set the default CV depending on whether clf is classifier/regressor."""
# Detect whether classification or regression
from sklearn.base import is_classifier

if estimator in ["classifier", "regressor"]:
est_is_classifier = estimator == "classifier"
else:
Expand Down Expand Up @@ -440,8 +442,7 @@ def cross_val_multiscore(
Array of scores of the estimator for each run of the cross validation.
"""
# This code is copied from sklearn

from sklearn.base import clone
from sklearn.base import clone, is_classifier
from sklearn.utils import indexable
from sklearn.model_selection._split import check_cv

Expand Down
3 changes: 1 addition & 2 deletions mne/decoding/receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from .base import get_coef, BaseEstimator, _check_estimator
from .time_delaying_ridge import TimeDelayingRidge
from ..fixes import is_regressor
from ..utils import _validate_type, verbose, fill_doc


Expand Down Expand Up @@ -183,7 +182,7 @@ def fit(self, X, y):
"scoring must be one of %s, got"
"%s " % (sorted(_SCORERS.keys()), self.scoring)
)
from sklearn.base import clone
from sklearn.base import clone, is_regressor

X, y, _, self._y_dim = self._check_dimensions(X, y)

Expand Down
10 changes: 8 additions & 2 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import pytest

from mne import create_info, EpochsArray
from mne.fixes import is_regressor, is_classifier
from mne.utils import requires_sklearn
from mne.decoding.base import (
_get_inverse_funcs,
Expand Down Expand Up @@ -69,7 +68,12 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
@requires_sklearn
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.base import (
TransformerMixin,
BaseEstimator,
is_classifier,
is_regressor,
)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn import svm
Expand Down Expand Up @@ -336,6 +340,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
def test_linearmodel():
"""Test LinearModel class for computing filters and patterns."""
# check categorical target fit in standard linear model
from sklearn.utils.estimator_checks import check_estimator
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
Expand Down Expand Up @@ -388,6 +393,7 @@ def test_linearmodel():
with pytest.raises(ValueError):
wrong_y = rng.rand(n, n_features, 99)
clf.fit(X, wrong_y)
check_estimator(clf)


@requires_sklearn
Expand Down
32 changes: 32 additions & 0 deletions mne/decoding/tests/test_receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,35 @@ def test_linalg_warning():
(RuntimeWarning, UserWarning), match="[Singular|scipy.linalg.solve]"
):
rf.fit(y, X)


@requires_sklearn
def test_tdr_sklearn_compliance():
    """Run the sklearn estimator checks against TimeDelayingRidge.

    Checks whose string representation contains one of the skipped
    fragments are not executed; every other generated check is run.
    """
    from sklearn.utils.estimator_checks import check_estimator

    # TimeDelayingRidge does not comply with a number of the regressor
    # specs, so those checks are filtered out by substring match.
    skipped = (
        "check_supervised_y_no_nan",
        "check_regressor",
        "check_parameters_default_constructible",
        "check_estimators_unfitted",
        "_invariance",
        "check_fit2d_1sample",
    )
    estimator = TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1)
    for candidate, run_check in check_estimator(estimator, generate_only=True):
        check_name = str(run_check)
        if not any(fragment in check_name for fragment in skipped):
            run_check(candidate)


@requires_sklearn
def test_rf_sklearn_compliance():
    """Run every generated sklearn estimator check on a ReceptiveField."""
    from sklearn.utils.estimator_checks import check_estimator
    from sklearn.linear_model import Ridge

    receptive_field = ReceptiveField(
        -0.1, 0.2, 1000.0, estimator=Ridge(), patterns=True
    )
    # generate_only=True yields (estimator, check) pairs instead of running
    # them, so each check is invoked explicitly here.
    generated = check_estimator(receptive_field, generate_only=True)
    for candidate, run_check in generated:
        run_check(candidate)
18 changes: 15 additions & 3 deletions mne/decoding/time_delaying_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..cuda import _setup_cuda_fft_multiply_repeated
from ..filter import next_fast_len
from ..fixes import jit
from ..utils import warn, ProgressBar, logger
from ..utils import warn, ProgressBar, logger, _validate_type, _check_option


def _compute_corrs(
Expand Down Expand Up @@ -301,6 +301,9 @@ def __init__(
self.edge_correction = edge_correction
self.n_jobs = n_jobs

def _more_tags(self):
return {"no_validation": True}

@property
def _smin(self):
return int(round(self.tmin * self.sfreq))
Expand All @@ -324,12 +327,21 @@ def fit(self, X, y):
self : instance of TimeDelayingRidge
Returns the modified instance.
"""
_validate_type(X, "array-like", "X")
_validate_type(y, "array-like", "y")
X = np.asarray(X, dtype=float)
y = np.asarray(y, dtype=float)
if X.ndim == 3:
assert y.ndim == 3
assert X.shape[:2] == y.shape[:2]
else:
assert X.ndim == 2 and y.ndim == 2
assert X.shape[0] == y.shape[0]
if X.ndim == 1:
X = X[:, np.newaxis]
if y.ndim == 1:
y = y[:, np.newaxis]
assert X.ndim == 2
assert y.ndim == 2
_check_option("y.shape[0]", y.shape[0], (X.shape[0],))
# These are split into two functions because it's possible that we
# might want to allow people to do them separately (e.g., to test
# different regularization parameters).
Expand Down
32 changes: 0 additions & 32 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,38 +176,6 @@ def _read_volume_info(fobj):
# adapted from scikit-learn


def is_classifier(estimator):
    """Return True if the given estimator is (probably) a classifier.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    Returns
    -------
    out : bool
        True if the estimator's ``_estimator_type`` attribute equals
        ``"classifier"``, and False otherwise (including when the
        attribute is missing).
    """
    # A missing attribute falls back to None, which never matches.
    kind = getattr(estimator, "_estimator_type", None)
    return kind == "classifier"


def is_regressor(estimator):
    """Return True if the given estimator is (probably) a regressor.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    Returns
    -------
    out : bool
        True if the estimator's ``_estimator_type`` attribute equals
        ``"regressor"``, and False otherwise (including when the
        attribute is missing).
    """
    # A missing attribute falls back to None, which never matches.
    kind = getattr(estimator, "_estimator_type", None)
    return kind == "regressor"


_DEFAULT_TAGS = {
"non_deterministic": False,
"requires_positive_X": False,
Expand Down

0 comments on commit 93b1604

Please sign in to comment.