Skip to content

Commit

Permalink
MAINT: Improve sklearn compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jul 31, 2023
1 parent a560999 commit b371e4b
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 158 deletions.
132 changes: 28 additions & 104 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime as dt
import numbers
from ..parallel import parallel_func
from ..fixes import BaseEstimator, is_classifier, _get_check_scoring
from ..fixes import BaseEstimator, _get_check_scoring
from ..utils import warn, verbose


Expand Down Expand Up @@ -51,14 +51,37 @@ class LinearModel(BaseEstimator):
.. footbibliography::
"""

# Attribute names that ``__getattr__`` looks up on the wrapped ``self.model``
# when they are not found on this object through normal attribute lookup.
_model_attr_wrap = (
"transform",
"predict",
"predict_proba",
"_estimator_type",
"decision_function",
"score",
"classes_",
)

def __init__(self, model=None):  # noqa: D102
    # Only the constructor parameter is stored. ``_estimator_type`` is
    # deliberately not cached on the instance: it is listed in
    # ``_model_attr_wrap`` and therefore read from ``self.model`` on access,
    # so it cannot go stale when the wrapped model is replaced.
    if model is None:
        # Local import: only executed when the default model is requested.
        from sklearn.linear_model import LogisticRegression

        model = LogisticRegression(solver="liblinear")

    self.model = model

def _more_tags(self):
    """Return the extra estimator tags (``no_validation`` is True)."""
    tags = {"no_validation": True}
    return tags

def __getattr__(self, attr):
    """Forward selected attribute lookups to the wrapped model.

    Python only calls this hook after normal attribute lookup has failed.

    Parameters
    ----------
    attr : str
        Name of the attribute that was not found on the instance.

    Returns
    -------
    value : object
        ``getattr(self.model, attr)`` for names in ``_model_attr_wrap``, or
        the bound ``_fit_transform`` method when ``fit_transform`` is
        requested and the wrapped model provides ``fit_transform``.

    Raises
    ------
    AttributeError
        If the name is not forwarded, or the wrapped model lacks it.
    """
    if attr in LinearModel._model_attr_wrap:
        return getattr(self.model, attr)
    if attr == "fit_transform" and hasattr(self.model, "fit_transform"):
        # ``_fit_transform`` is a regular method, so plain lookup finds it;
        # going through ``super()`` with an explicit ``self`` would pass the
        # instance twice to an already-bound call.
        return self._fit_transform
    # Nothing to forward: report the missing attribute the standard way so
    # that ``hasattr`` and ``getattr(..., default)`` keep working.
    raise AttributeError(
        f"{type(self).__name__!r} object has no attribute {attr!r}"
    )

def _fit_transform(self, X, y):
    """Fit on ``(X, y)``, then return the transform of the same ``X``."""
    fitted = self.fit(X, y)
    return fitted.transform(X)

def fit(self, X, y, **fit_params):
"""Estimate the coefficients of the linear model.
Expand Down Expand Up @@ -120,110 +143,12 @@ def filters_(self):
filters = filters[0]
return filters

def transform(self, X):
    """Transform the data using the linear model.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data to transform.

    Returns
    -------
    y_pred : array, shape (n_samples,)
        The predicted targets.
    """
    wrapped = self.model
    return wrapped.transform(X)

def fit_transform(self, X, y):
    """Fit the data and transform it using the linear model.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The training input samples to estimate the linear coefficients.
    y : array, shape (n_samples,)
        The target values.

    Returns
    -------
    y_pred : array, shape (n_samples,)
        The predicted targets.
    """
    fitted = self.fit(X, y)
    return fitted.transform(X)

def predict(self, X):
    """Compute predictions of y from X.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data used to compute the predictions.

    Returns
    -------
    y_pred : array, shape (n_samples,)
        The predictions.
    """
    predictor = self.model.predict
    return predictor(X)

def predict_proba(self, X):
    """Compute probabilistic predictions of y from X.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data used to compute the predictions.

    Returns
    -------
    y_pred : array, shape (n_samples, n_classes)
        The probabilities.
    """
    wrapped = self.model
    probabilities = wrapped.predict_proba(X)
    return probabilities

def decision_function(self, X):
    """Compute distance from the decision function of y from X.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data used to compute the predictions.

    Returns
    -------
    y_pred : array, shape (n_samples, n_classes)
        The distances.
    """
    wrapped = self.model
    distances = wrapped.decision_function(X)
    return distances

def score(self, X, y):
    """Score the linear model computed on the given test data.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data to transform.
    y : array, shape (n_samples,)
        The target values.

    Returns
    -------
    score : float
        Score of the linear model.
    """
    scorer = self.model.score
    return scorer(X, y)

# Needed for sklearn 1.3+
@property
def classes_(self):
    """The classes (pass-through to model)."""
    wrapped = self.model
    return wrapped.classes_


def _set_cv(cv, estimator=None, X=None, y=None):
"""Set the default CV depending on whether clf is classifier/regressor."""
# Detect whether classification or regression
from sklearn.base import is_classifier

if estimator in ["classifier", "regressor"]:
est_is_classifier = estimator == "classifier"
else:
Expand Down Expand Up @@ -440,8 +365,7 @@ def cross_val_multiscore(
Array of scores of the estimator for each run of the cross validation.
"""
# This code is copied from sklearn

from sklearn.base import clone
from sklearn.base import clone, is_classifier
from sklearn.utils import indexable
from sklearn.model_selection._split import check_cv

Expand Down
22 changes: 16 additions & 6 deletions mne/decoding/receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from .base import get_coef, BaseEstimator, _check_estimator
from .time_delaying_ridge import TimeDelayingRidge
from ..fixes import is_regressor
from ..utils import _validate_type, verbose, fill_doc


Expand Down Expand Up @@ -128,6 +127,9 @@ def __init__(
self.n_jobs = n_jobs
self.edge_correction = edge_correction

def _more_tags(self):
    """Return the extra estimator tags (``no_validation`` is True)."""
    tags = {"no_validation": True}
    return tags

def __repr__(self): # noqa: D105
s = "tmin, tmax : (%.3f, %.3f), " % (self.tmin, self.tmax)
estimator = self.estimator
Expand All @@ -153,7 +155,11 @@ def _delay_and_reshape(self, X, y=None):
if not isinstance(self.estimator_, TimeDelayingRidge):
# X is now shape (n_times, n_epochs, n_feats, n_delays)
X = _delay_time_series(
X, self.tmin, self.tmax, self.sfreq, fill_mean=self.fit_intercept
X,
self.tmin,
self.tmax,
self.sfreq,
fill_mean=self.fit_intercept_,
)
X = _reshape_for_est(X)
# Concat times + epochs
Expand Down Expand Up @@ -183,7 +189,7 @@ def fit(self, X, y):
"scoring must be one of %s, got"
"%s " % (sorted(_SCORERS.keys()), self.scoring)
)
from sklearn.base import clone
from sklearn.base import clone, is_regressor

X, y, _, self._y_dim = self._check_dimensions(X, y)

Expand All @@ -199,13 +205,15 @@ def fit(self, X, y):

if isinstance(self.estimator, numbers.Real):
if self.fit_intercept is None:
self.fit_intercept = True
self.fit_intercept_ = True
else:
self.fit_intercept_ = self.fit_intercept
estimator = TimeDelayingRidge(
self.tmin,
self.tmax,
self.sfreq,
alpha=self.estimator,
fit_intercept=self.fit_intercept,
fit_intercept=self.fit_intercept_,
n_jobs=self.n_jobs,
edge_correction=self.edge_correction,
)
Expand All @@ -221,7 +229,7 @@ def fit(self, X, y):
"same fit_intercept value or use fit_intercept=None"
% (estimator.fit_intercept, self.fit_intercept)
)
self.fit_intercept = estimator.fit_intercept
self.fit_intercept_ = estimator.fit_intercept
else:
raise ValueError(
"`estimator` must be a float or an instance"
Expand Down Expand Up @@ -354,6 +362,8 @@ def score(self, X, y):
return scores

def _check_dimensions(self, X, y, predict=False):
_validate_type(X, "array-like", "X")
_validate_type(y, ("array-like", None), "y")
X_dim = X.ndim
y_dim = y.ndim if y is not None else 0
if X_dim == 2:
Expand Down
46 changes: 36 additions & 10 deletions mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class SlidingEstimator(BaseEstimator, TransformerMixin):
%(scoring)s
%(n_jobs)s
%(position)s
allow_2d : bool
If True, allow 2D data as input (i.e. n_samples, n_features).
%(verbose)s
Attributes
Expand All @@ -35,16 +37,30 @@ class SlidingEstimator(BaseEstimator, TransformerMixin):

@verbose
def __init__(
    self,
    base_estimator,
    scoring=None,
    n_jobs=None,
    *,
    position=0,
    allow_2d=False,
    verbose=None,
):  # noqa: D102
    _check_estimator(base_estimator)
    # ``_estimator_type`` is not assigned here: the class exposes it as a
    # read-only property derived from ``base_estimator``, so an assignment
    # would fail and would also duplicate state.
    self.base_estimator = base_estimator
    self.n_jobs = n_jobs
    self.scoring = scoring
    self.position = position
    self.allow_2d = allow_2d
    self.verbose = verbose

def _more_tags(self):
    """Return the extra estimator tags declared for this estimator."""
    tags = {"no_validation": True}
    tags["requires_fit"] = False
    return tags

@property
def _estimator_type(self):
    """Estimator type of ``base_estimator``, or None if it has none."""
    base = self.base_estimator
    return getattr(base, "_estimator_type", None)

def __repr__(self): # noqa: D105
repr_str = "<" + super(SlidingEstimator, self).__repr__()
if hasattr(self, "estimators_"):
Expand Down Expand Up @@ -72,12 +88,12 @@ def fit(self, X, y, **fit_params):
self : object
Return self.
"""
self._check_Xy(X, y)
X = self._check_Xy(X, y)
parallel, p_func, n_jobs = parallel_func(
_sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False
)
self.estimators_ = list()
self.fit_params = fit_params
self.fit_params_ = fit_params

# For fitting, the parallelization is across estimators.
context = _create_progressbar_context(self, X, "Fitting")
Expand Down Expand Up @@ -123,7 +139,7 @@ def fit_transform(self, X, y, **fit_params):

def _transform(self, X, method):
"""Aux. function to make parallel predictions/transformation."""
self._check_Xy(X)
X = self._check_Xy(X)
method = _check_method(self.base_estimator, method)
if X.shape[-1] != len(self.estimators_):
raise ValueError("The number of estimators does not match " "X.shape[-1]")
Expand All @@ -144,7 +160,7 @@ def _transform(self, X, method):
)

y_pred = np.concatenate(y_pred, axis=1)
return y_pred
return y_pred.astype(X.dtype)

def transform(self, X):
"""Transform each data slice/task with a series of independent estimators.
Expand Down Expand Up @@ -237,11 +253,21 @@ def decision_function(self, X):

def _check_Xy(self, X, y=None):
    """Validate the input data and return ``X`` as an array.

    Parameters
    ----------
    X : array-like
        The data. At least 3 dimensions are required unless
        ``self.allow_2d`` is True, in which case 2D input is accepted and
        a trailing singleton axis is appended.
    y : array-like | None
        Optional targets. When given, they must be non-empty and have the
        same length as ``X``.

    Returns
    -------
    X : ndarray
        The data converted with ``np.asarray``; input with fewer than 3
        dimensions that passed validation gains a trailing axis of length 1.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` differ in length, ``y`` is empty, or ``X`` has
        too few dimensions.
    """
    X = np.asarray(X)
    if y is not None:
        y = np.asarray(y)
        if len(X) != len(y) or len(y) < 1:
            raise ValueError("X and y must have the same length.")
    if X.ndim < 3:
        # Decide which minimum applies before raising, so that 2D input
        # is actually reachable when ``allow_2d`` is enabled.
        if not self.allow_2d:
            min_ndim = 3
        elif X.ndim < 2:
            min_ndim = 2
        else:
            min_ndim = None
        if min_ndim is not None:
            raise ValueError(f"X must have at least {min_ndim} dimensions.")
        # Append a singleton last axis so callers can index ``X.shape[-1]``.
        X = X[..., np.newaxis]
    return X

def score(self, X, y):
"""Score each estimator on each task.
Expand All @@ -268,7 +294,7 @@ def score(self, X, y):
""" # noqa: E501
check_scoring = _get_check_scoring()

self._check_Xy(X)
X = self._check_Xy(X, y)
if X.shape[-1] != len(self.estimators_):
raise ValueError("The number of estimators does not match " "X.shape[-1]")

Expand Down Expand Up @@ -446,7 +472,7 @@ def __repr__(self): # noqa: D105

def _transform(self, X, method):
"""Aux. function to make parallel predictions/transformation."""
self._check_Xy(X)
X = self._check_Xy(X)
method = _check_method(self.base_estimator, method)

parallel, p_func, n_jobs = parallel_func(
Expand Down Expand Up @@ -567,7 +593,7 @@ def score(self, X, y):
Score for each estimator / data slice couple.
""" # noqa: E501
check_scoring = _get_check_scoring()
self._check_Xy(X)
X = self._check_Xy(X, y)
# For predictions/transforms the parallelization is across the data and
# not across the estimators to avoid memory load.
parallel, p_func, n_jobs = parallel_func(
Expand Down
Loading

0 comments on commit b371e4b

Please sign in to comment.