diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index c771d394c..c045599ab 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -178,11 +178,98 @@ This is basically the same as calling :py:func:`~scikeras.wrappers.BaseWrapper.g Data Transformers ^^^^^^^^^^^^^^^^^ -In some cases, the input actually consists of multiple inputs. E.g., +Keras supports a much wider range of inputs/outputs than Scikit-Learn does. E.g., in a text classification task, you might have an array that contains the integers representing the tokens for each sample, and another -array containing the number of tokens of each sample. SciKeras has you -covered here as well. +array containing the number of tokens of each sample. + +In order to reconcile Keras' expanded input/output support and Scikit-Learn's more +limited options, SciKeras introduces "data transformers". These are really just +dependency injection points where you can declare custom data transformations, +for example to split an array into a list of arrays, join ``X`` & ``y`` into a ``Dataset``, etc. +In order to keep these transformations in a familiar format, they are implemented as +sklearn-style transformers. You can think of this setup as an sklearn Pipeline: + +.. code-block:: + + ↗ feature_encoder ↘ + SciKeras.fit(features, labels) dataset_transformer → keras.Model.fit(data) + ↘ target_encoder ↗ + + +Within SciKeras, this is roughly implemented as follows: + +.. code:: python + + class PseudoBaseWrapper: + + def fit(self, X, y, sample_weight): + self.target_encoder_ = self.target_encoder.fit(X) + X = self.feature_encoder_.transform(X) + self.feature_encoder_ = self.feature_encoder.fit(y) + y = self.target_encoder_.transform(y) + self.model_ = self._build_keras_model() + fit_kwargs = dict(x=X, y=y, sample_weight=sample_weight) + self.dataset_transformer_ = self.dataset_transformer.fit(fit_kwargs) + fit_kwargs = self.dataset_transformer_.transform(fit_kwargs) + self.model_.fit(x=X, y=y, sample_weight=sample_weight) # tf.keras.Model.fit + return self + + def predict(self, X): + X = self.feature_encoder_.transform(X) + predict_kwargs = dict(x=X) + predict_kwargs = self.dataset_transformer_.fit_transform(predict_kwargs) + y_pred = self.model_.predict(**predict_kwargs) + return self.target_encoder_.inverse_transform(y_pred) + + +``dataset_transformer`` is the last step before passing the data to Keras, and it allows for the greatest +degree of customization because SciKeras does not make any assumptions about the output data +and passes it directly to :py:func:`tensorflow.keras.Model.fit`. + +It accepts a dict of valid Keras ``**kwargs`` and is expected to return a dict +of valid Keras ``**kwargs``: + +.. code:: python + + from sklearn.base import BaseEstimator, TransformerMixin + + class DatasetTransformer(BaseEstimator, TransformerMixin): + def fit(self, data: Dict[str, Any]) -> "DatasetTransformer": + assert data.keys() == {"x", "y", "sample_weight"} # fixed keys + ... + return self + + def transform(self, data): # return a valid input for keras.Model.fit + # data includes x, y, sample_weight + assert "x" in data # "x" is always a keys + if "y" in data: + # called from fit + else: + # called from predict + # as well as other Model.fit or Model.predict arguments + assert "batch_size" in data + ... + return data + + +You can modify ``data`` in-place within ``transoform`` but **must** still return +it. + +When called from ``fit`` or ``initialize``, you will get and return keys that are valid +``**kwargs`` to ``tf.keras.Model.fit``. When being called from ``predict`` or ``score`` +you will get and return keys that are valid ``**kwargs`` to ``tf.keras.Model.predict``. + +Although you could implement *all* data transformations in a single ``dataset_transformer``, +having several distinct dependency injections points allows for more modularity, +for example to keep the default processing of string-encoded labels but convert +the data to a :py:func:`tensorflow.data.Dataset` before passing to Keras. + +For a complete examples implementing custom data processing, see the examples in the +:ref:`tutorials` section. + +Multi-input and output models via feature_encoder and target_encoder +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Scikit-Learn natively supports multiple outputs, although it technically requires them to be arrays of equal length @@ -190,14 +277,15 @@ requires them to be arrays of equal length Scikit-Learn has no support for multiple inputs. To work around this issue, SciKeras implements a data conversion abstraction in the form of Scikit-Learn style transformers, -one for ``X`` (features) and one for ``y`` (target). -By implementing a custom transformer, you can split a single input ``X`` into multiple inputs -for :py:class:`tensorflow.keras.Model` or perform any other manipulation you need. +one for ``X`` (features) and one for ``y`` (target). These are implemented +via :py:func:`scikeras.wrappers.BaseWrappers.feature_encoder` and +:py:func:`scikeras.wrappers.BaseWrappers.feature_encoder` respectively. + To override the default transformers, simply override :py:func:`scikeras.wrappers.BaseWrappers.target_encoder` or :py:func:`scikeras.wrappers.BaseWrappers.feature_encoder` for ``y`` and ``X`` respectively. -SciKeras uses :py:func:`sklearn.utils.multiclass.type_of_target` to categorize the target +By default, SciKeras uses :py:func:`sklearn.utils.multiclass.type_of_target` to categorize the target type, and implements basic handling of the following cases out of the box: +--------------------------+--------------+----------------+----------------+---------------+ @@ -208,11 +296,11 @@ type, and implements basic handling of the following cases out of the box: +--------------------------+--------------+----------------+----------------+---------------+ | "binary" | [1, 0, 1] | 1 | 1 or 2 | Yes | +--------------------------+--------------+----------------+----------------+---------------+ -| "mulilabel-indicator" | [[1, 1], | 1 or >1 | 2 per target | Single output | +| "multilabel-indicator" | [[1, 1], | 1 or >1 | 2 per target | Single output | | | | | | | -| | [0, 2], | | | only | +| | [0, 1], | | | only | | | | | | | -| | [1, 1]] | | | | +| | [1, 0]] | | | | +--------------------------+--------------+----------------+----------------+---------------+ | "multiclass-multioutput" | [[1, 1], | >1 | >=2 per target | No | | | | | | | @@ -229,11 +317,47 @@ type, and implements basic handling of the following cases out of the box: | | [.2, .9]] | | | | +--------------------------+--------------+----------------+----------------+---------------+ -If you find that your target is classified as ``"multiclass-multioutput"`` or ``"unknown"``, you will have to -implement your own data processing routine. +The supported cases are handled by the default implementation of ``target_encoder``. +The default implementations are available for use as :py:class:`scikeras.utils.transformers.ClassifierLabelEncoder` +and :py:class:`scikeras.utils.transformers.RegressorTargetEncoder` for +:py:class:`scikeras.wrappers.KerasClassifier` and :py:class:`scikeras.wrappers.KerasRegressor` respectively. -For a complete examples implementing custom data processing, see the examples in the -:ref:`tutorials` section. +As per the table above, if you find that your target is classified as +``"multiclass-multioutput"`` or ``"unknown"``, you will have to implement your own data processing routine. + +get_metadata method ++++++++++++++++++++ + +In addition to converting data, ``feature_encoder`` and ``target_encoder``, allows you to inject data +into your model construction method. This is useful if for example you use ``target_encoder`` to dynamically +determine how many outputs your model should have based on the data and then use this information to +assign the right number of outputs in your Model. To return data from ``feature_encoder`` or ``target_encoder``, +you will need to provide a transformer with a ``get_metadata`` method, which is expected to return a dictionary +which will be injected into your model building function via the ``meta`` parameter. + +For example, if you wanted to create a calculated parameter called ``my_param_``: + +.. code-block::python + + class MultiOutputTransformer(BaseEstimator, TransformerMixin): + def get_metadata(self): + return {"my_param_": "foobarbaz"} + + class MultiOutputClassifier(KerasClassifier): + + @property + def target_encoder(self): + return MultiOutputTransformer(...) + + def get_model(meta): + print(f"Got: {meta['my_param_']}") + + clf = MultiOutputClassifier(model=get_model) + clf.fit(X, y) # prints 'Got: foobarbaz' + print(clf.my_param_) # prints 'foobarbaz' + +Note that it is best practice to end your parameter names with a single underscore, +which allows sklearn to know which parameters are stateful and which are stateless. Routed parameters ----------------- @@ -282,6 +406,8 @@ Custom Scorers SciKeras uses :func:`sklearn.metrics.accuracy_score` and :func:`sklearn.metrics.accuracy_score` as the scoring functions for :class:`scikeras.wrappers.KerasClassifier` and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these scoring functions, +override :func:`scikeras.wrappers.KerasClassifier.scorer` +or :func:`scikeras.wrappers.KerasRegressor.scorer`. .. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks diff --git a/docs/source/notebooks/DataTransformers.md b/docs/source/notebooks/DataTransformers.md index 1cb66bdd5..da3d51499 100644 --- a/docs/source/notebooks/DataTransformers.md +++ b/docs/source/notebooks/DataTransformers.md @@ -5,8 +5,8 @@ jupyter: text_representation: extension: .md format_name: markdown - format_version: '1.2' - jupytext_version: 1.9.1 + format_version: '1.3' + jupytext_version: 1.10.1 kernelspec: display_name: Python 3 language: python @@ -25,24 +25,29 @@ Keras support many types of input and output data formats, including: * Multiple outputs * Higher-dimensional tensors -In this notebook, we explore how to reconcile this functionality with the sklearn ecosystem via SciKeras data transformer interface. +This notebook walks through an example of the different data transformations and how SciKeras bridges Keras and Scikit-learn. +It may be helpful to have a general understanding of the dataflow before tackling these examples, which is available in +the [data transformer docs](https://www.adriangb.com/scikeras/refs/heads/master/advanced.html#data-transformers). ## Table of contents * [1. Setup](#1.-Setup) -* [2. Data transformer interface](#2.-Data-transformer-interface) - * [2.1 get_metadata method](#2.1-get_metadata-method) -* [3. Multiple outputs](#3.-Multiple-outputs) +* [2. Multiple outputs](#2.-Multiple-outputs) + * [2.1 Define Keras Model](#2.1-Define-Keras-Model) + * [2.2 Define output data transformer](#2.2-Define-output-data-transformer) + * [2.3 Test classifier](#2.3-Test-classifier) +* [3. Multiple inputs](#3-multiple-inputs) * [3.1 Define Keras Model](#3.1-Define-Keras-Model) - * [3.2 Define output data transformer](#3.2-Define-output-data-transformer) - * [3.3 Test classifier](#3.3-Test-classifier) -* [4. Multiple inputs](#4-multiple-inputs) + * [3.2 Define data transformer](#3.2-Define-data-transformer) + * [3.3 Test regressor](#3.3-Test-regressor) +* [4. Multidimensional inputs with MNIST dataset](#4.-Multidimensional-inputs-with-MNIST-dataset) * [4.1 Define Keras Model](#4.1-Define-Keras-Model) - * [4.2 Define data transformer](#4.2-Define-data-transformer) - * [4.3 Test regressor](#4.3-Test-regressor) -* [5. Multidimensional inputs with MNIST dataset](#5.-Multidimensional-inputs-with-MNIST-dataset) - * [5.1 Define Keras Model](#5.1-Define-Keras-Model) - * [5.2 Test](#5.2-Test) + * [4.2 Test](#4.2-Test) +* [5. Ragged datasets with tf.data.Dataset](#5.-Ragged-datasets-with-tf.data.Dataset) +* [6. Multi-output class_weight](#6.-Multi-output-class_weight) +* [7. Custom validation dataset](#6.-Custom-validation-dataset) +* [8. Dynamically setting batch_size](#6.-Dynamically-setting-batch_size) + ## 1. Setup @@ -68,6 +73,9 @@ from scikeras.wrappers import KerasClassifier, KerasRegressor from tensorflow import keras ``` +<<<<<<< HEAD +## 2. Multiple outputs +======= ## 2. Data transformer interface SciKeras enables advanced Keras use cases by providing an interface to convert sklearn compliant data to whatever format your Keras model requires within SciKeras, right before passing said data to the Keras model. @@ -168,6 +176,7 @@ if False: # avoid executing pseudocode ``` ## 3. Multiple outputs +>>>>>>> master Keras makes it straight forward to define models with multiple outputs, that is a Model with multiple sets of fully-connected heads at the end of the network. This functionality is only available in the Functional Model and subclassed Model definition modes, and is not available when using Sequential. @@ -175,7 +184,7 @@ In practice, the main thing about Keras models with multiple outputs that you ne Note that "multiple outputs" in Keras has a slightly different meaning than "multiple outputs" in sklearn. Many tasks that would be considered "multiple output" tasks in sklearn can be mapped to a single "output" in Keras with multiple units. This notebook specifically focuses on the cases that require multiple distinct Keras outputs. -### 3.1 Define Keras Model +### 2.1 Define Keras Model Here we define a simple perceptron that has two outputs, corresponding to one binary classification taks and one multiclass classification task. For example, one output might be "image has car" (binary) and the other might be "color of car in image" (multiclass). @@ -227,7 +236,7 @@ Our data transormer's job will be to convert from a single numpy array (which is We will structure our data on the sklearn side by column-stacking our list of arrays. This works well in this case since we have the same number of datapoints in each array. -### 3.2 Define output data transformer +### 2.2 Define output data transformer Let's go ahead and protoype this data transformer: @@ -287,7 +296,7 @@ class MultiOutputTransformer(BaseEstimator, TransformerMixin): Note that in addition to the usual `transform` and `inverse_transform` methods, we implement the `get_metadata` method to return the `n_classes_` attribute. -Lets test our transformer with the same dataset we previoulsy used to test our model: +Lets test our transformer with the same dataset we previously used to test our model: ```python tf = MultiOutputTransformer() @@ -329,7 +338,7 @@ class MultiOutputClassifier(KerasClassifier): return np.mean([accuracy_score(y_bin, y_pred_bin), accuracy_score(y_cat, y_pred_cat)]) ``` -### 3.3 Test classifier +### 2.3 Test classifier ```python from sklearn.preprocessing import StandardScaler @@ -345,27 +354,22 @@ clf = MultiOutputClassifier(model=get_clf_model, verbose=0, random_state=0) clf.fit(X, y_sklearn).score(X, y_sklearn) ``` -## 4. Multiple inputs +## 3. Multiple inputs The process for multiple inputs is similar, but instead of overriding the transformer in `target_encoder` we override `feature_encoder`. -```python -if False: - from sklearn.base import BaseEstimator, TransformerMixin - - - class MultiInputTransformer(BaseEstimator, TransformerMixin): - ... - - class MultiInputClassifier(KerasClassifier): +```python .noeval +class MultiInputTransformer(BaseEstimator, TransformerMixin): + ... - @property - def feature_encoder(self): - return MultiInputTransformer(...) +class MultiInputClassifier(KerasClassifier): + @property + def feature_encoder(self): + return MultiInputTransformer(...) ``` -### 4.1 Define Keras Model +### 3.1 Define Keras Model Let's define a Keras **regression** Model with 2 inputs: @@ -409,7 +413,7 @@ r2_score(y, y_pred) Having verified that our model builds without errors and accepts the inputs types we expect, we move onto integrating a transformer into our SciKeras model. -### 4.2 Define data transformer +### 3.2 Define data transformer Just like for overriding `target_encoder`, we just need to define a sklearn transformer and drop it into our SciKeras wrapper. Since we hardcoded the input shapes into our model and do not rely on any transformer-generated metadata, we can simply use `sklearn.preprocessing.FunctionTransformer`: @@ -429,7 +433,7 @@ class MultiInputRegressor(KerasRegressor): Note that we did **not** implement `inverse_transform` (that is, we did not pass an `inverse_func` argument to `FunctionTransformer`) because features are never converted back to their original form. -### 4.3 Test regressor +### 3.3 Test regressor ```python reg = MultiInputRegressor(model=get_reg_model, verbose=0, random_state=0) @@ -439,7 +443,7 @@ X_sklearn = np.column_stack(X) reg.fit(X_sklearn, y).score(X_sklearn, y) ``` -## 5. Multidimensional inputs with MNIST dataset +## 4. Multidimensional inputs with MNIST dataset In this example, we look at how we can use SciKeras to process the MNIST dataset. The dataset is composed of 60,000 images of digits, each of which is a 2D 28x28 image. @@ -469,6 +473,10 @@ x_train = x_train.reshape((n_samples_train, -1)) x_test = x_test.reshape((n_samples_test, -1)) x_train = MinMaxScaler().fit_transform(x_train) x_test = MinMaxScaler().fit_transform(x_test) + +# reduce dataset size for faster training +n_samples = 1000 +x_train, y_train, x_test, y_test = x_train[:n_samples], y_train[:n_samples], x_test[:n_samples], y_test[:n_samples] ``` ```python @@ -481,7 +489,7 @@ print(np.min(x_train), np.max(x_train)) # scaled 0-1 Of course, in this case, we could have just as easily used numpy functions to scale our data, but we use `MinMaxScaler` to demonstrate use of the sklearn ecosystem. -### 5.1 Define Keras Model +### 4.1 Define Keras Model Next we will define our Keras model (adapted from [keras.io](https://keras.io/examples/vision/mnist_convnet/)): @@ -531,15 +539,491 @@ clf = MultiDimensionalClassifier( ) ``` -### 5.2 Test +### 4.2 Test Train and score the model (this takes some time) ```python -clf.fit(x_train, y_train) +_ = clf.fit(x_train, y_train) ``` ```python score = clf.score(x_test, y_test) print(f"Test score (accuracy): {score:.2f}") ``` + +## 5. Ragged datasets with tf.data.Dataset + +SciKeras provides a third dependency injection point that operates on the entire dataset: X, y & sample_weight. +This `dataset_transformer` is applied after `target_transformer` and `feature_transformer`. +One use case for this dependency injection point is to transform data from tabular/array-like to the `tf.data.Dataset` format, which only requires iteration. +We can use this to create a `tf.data.Dataset` of ragged tensors. + +Note that `dataset_transformer` should accept a single single dictionary as its argument to `transform` and `fit`, and return a single dictionary as well. +More details on this are in the [docs](https://www.adriangb.com/scikeras/refs/heads/master/advanced.html#data-transformers). + +Let's start by defining our data. We'll have an extra "feature" that marks the observation index, but we'll remove it when we deconstruct our data in the transformer. + +```python +feature_1 = np.random.uniform(size=(10, )) +feature_2 = np.random.uniform(size=(10, )) +obs = [0, 0, 0, 1, 1, 2, 3, 3, 4, 4] + +X = np.column_stack([feature_1, feature_2, obs]).astype("float32") + +y = np.array(["class1"] * 5 + ["class2"] * 5, dtype=str) +``` + +Next, we define our `dataset_transformer`. We will do this by defining a custom forward transformation outside of the Keras model. Note that we do not define an inverse transformation since that is never used. +Also note that `dataset_transformer` will _always_ be called with `X` (i.e. the first element of the tuple will always be populated), but will be called with `y=None` when used for `predict`. Thus, +you should check if `y` and `sample_weigh` are None before doing any operations on them. + +```python +from typing import Dict, Any + +import tensorflow as tf + + +def ragged_transformer(data: Dict[str, Any]) -> Dict[str, Any]: + x, y, sample_weight = data["x"], data.get("y", None), data.get("sample_weight", None) + if y is not None: + y = y.reshape(-1, 1 if len(y.shape) == 1 else y.shape[1]) + y = y[tf.RaggedTensor.from_value_rowids(y, x[:, -1]).row_starts().numpy()] + if sample_weight is not None: + sample_weight = sample_weight.reshape(-1, 1 if len(sample_weight.shape) == 1 else sample_weight.shape[1]) + sample_weight = sample_weight[tf.RaggedTensor.from_value_rowids(sample_weight, x[:, -1]).row_starts().numpy()] + x = tf.RaggedTensor.from_value_rowids(x[:, :-1], x[:, -1]) + data["x"] = x + if "y" in data: + data["y"] = y + if "sample_weight" in data: + data["sample_weight"] = sample_weight + return data +``` + +In this case, we chose to keep `y` and `sample_weight` as numpy arrays, which will allow us to re-use ClassWeightDataTransformer, +the default `dataset_transformer` for `KerasClassifier`. + +Lets quickly test our transformer: + +```python +data = ragged_transformer(dict(x=X, y=y, sample_weight=None)) +print(type(data["x"])) +print(data["x"].shape) +``` + +And the `y=None` case: + +```python +data = ragged_transformer(dict(x=X, y=None, sample_weight=None)) +print(type(data["x"])) +print(data["x"].shape) +``` + +Everything looks good! + +Because Keras will not accept a RaggedTensor directly, we will need to wrap our entire dataset into a tensorflow `Dataset`. We can do this by adding one more transformation step: + +Next, we can add our transormers to our model. We use an sklearn `Pipeline` (generated via `make_pipeline`) to keep ClassWeightDataTransformer operational while implementing our custom transformation. + +```python +def dataset_transformer(data: Dict[str, Any]) -> Dict[str, Any]: + x_y_s = data["x"], data.get("y", None), data.get("sample_weight", None) + data["x"] = tf.data.Dataset.from_tensor_slices(x_y_s) + # don't blindly assign y & sw; if being called from + # predict they should not just be None, they should not be present at all! + if "y" in data: + data["y"] = None + if "sample_weight" in data: + data["sample_weight"] = None + return data +``` + +```python +from sklearn.preprocessing import FunctionTransformer +from sklearn.pipeline import make_pipeline + + +class RaggedClassifier(KerasClassifier): + + @property + def dataset_transformer(self): + t1 = FunctionTransformer(ragged_transformer) + t2 = super().dataset_transformer # ClassWeightDataTransformer + t3 = FunctionTransformer(dataset_transformer) + t4 = "passthrough" # see https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators + return make_pipeline(t1, t2, t3, t4) +``` + +Now we can define a Model. We need some way to handle/flatten our ragged arrays within our model. For this example, we use a custom mean layer, but you could use an Embedding layer, LSTM, etc. + +```python +from tensorflow import reduce_mean, reshape +from tensorflow.keras import Sequential, layers + + +class CustomMean(layers.Layer): + + def __init__(self, axis=None): + super(CustomMean, self).__init__() + self._supports_ragged_inputs = True + self.axis = axis + + def call(self, inputs, **kwargs): + input_shape = inputs.get_shape() + return reshape(reduce_mean(inputs, axis=self.axis), (1, *input_shape[1:])) + + +def get_model(meta): + inp_shape = meta["X_shape_"][1]-1 + model = Sequential([ + layers.Input(shape=(inp_shape,), ragged=True), + CustomMean(axis=0), + layers.Dense(1, activation='sigmoid') + ]) + return model +``` + +And attach our model to our classifier wrapper: + +```python +clf = RaggedClassifier(get_model, loss="bce") +``` + +Finally, let's train and predict: + +```python +clf.fit(X, y) +y_pred = clf.predict(X) +y_pred +``` + +If we define our custom layers, transformers and wrappers in their own module, we can easily create a self-contained classifier that is able to handle ragged datasets and has a clean Scikit-Learn compatible API: + +```python +class RaggedClassifier(KerasClassifier): + + @property + def dataset_transformer(self): + t1 = FunctionTransformer(ragged_transformer) + t2 = super().dataset_transformer # ClassWeightDataTransformer + t3 = FunctionTransformer(dataset_transformer) + t4 = "passthrough" # see https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators + return make_pipeline(t1, t2, t3, t4) + + def _keras_build_fn(self): + inp_shape = self.X_shape_[1] - 1 + model = Sequential([ + layers.Input(shape=(inp_shape,), ragged=True), + CustomMean(axis=0), + layers.Dense(1, activation='sigmoid') + ]) + return model +``` + +```python +clf = RaggedClassifier(loss="bce") +clf.fit(X, y) +y_pred = clf.predict(X) +y_pred +``` + +## 6. Multi-output class_weight + +In this example, we will use `dataset_transformer` to support multi-output class weights. +We will re-use our `MultiOutputTransformer` from our previous example to split the output, then we will create `sample_weight` from `class_weight`. + +```python +from collections import defaultdict +from typing import Union + +from sklearn.utils.class_weight import compute_sample_weight + + +class DatasetTransformer(BaseEstimator, TransformerMixin): + + def __init__(self, output_names): + self.output_names = output_names + + def fit(self, data: Dict[str, Any]) -> "DatasetTransformer": + return self + + def transform(self, data: Dict[str, Any]) -> Dict[str, Any]: + class_weight = data.get("class_weight", None) + if class_weight is None: + return data + if isinstance(class_weight, str): # handle "balanced" + class_weight_ = class_weight + class_weight = defaultdict(lambda: class_weight_) + y, sample_weight = data.get("y", None), data.get("sample_weight", None) + assert sample_weight is None, "Cannot use class_weight & sample_weight together" + if y is not None: + # y should be a list of arrays, as split up by MultiOutputTransformer + sample_weight = { + output_name: compute_sample_weight(class_weight[output_num], output_data) + for output_num, (output_name, output_data) in enumerate(zip(self.output_names, y)) + } + # Note: class_weight is expected to be indexable by output_number in sklearn + # see https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_sample_weight.html + # It is trivial to change the expected format to match Keras' ({output_name: weights, ...}) + # see https://github.com/keras-team/keras/issues/4735#issuecomment-267473722 + data["sample_weight"] = sample_weight + data["class_weight"] = None + return data + + +def get_model(meta, compile_kwargs): + inp = keras.layers.Input(shape=(meta["n_features_in_"])) + x1 = keras.layers.Dense(100, activation="relu")(inp) + out_bin = keras.layers.Dense(1, activation="sigmoid")(x1) + out_cat = keras.layers.Dense(meta["n_classes_"][1], activation="softmax")(x1) + model = keras.Model(inputs=inp, outputs=[out_bin, out_cat]) + model.compile( + loss=["binary_crossentropy", "sparse_categorical_crossentropy"], + optimizer=compile_kwargs["optimizer"] + ) + return model + + +class CustomClassifier(KerasClassifier): + + @property + def target_encoder(self): + return MultiOutputTransformer() + + @property + def dataset_transformer(self): + return DatasetTransformer( + output_names=self.model_.output_names, + ) +``` + +Next, we define the data. We'll use `sklearn.datasets.make_blobs` to generate a relatively noisy dataset: + +```python +from sklearn.datasets import make_blobs + + +X, y = make_blobs(centers=3, random_state=0, cluster_std=20) +# make a binary target for "is the value of the first class?" +y_bin = y == y[0] +y = np.column_stack([y_bin, y]) +``` + +Test the model without specifying class weighting: + +```python +clf = CustomClassifier(get_model, epochs=100, verbose=0, random_state=0) +clf.fit(X, y) +y_pred = clf.predict(X) +(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True) +print(counts_bin) +(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True) +print(counts_cat) +``` + +As you can see, without `class_weight="balanced"`, our classifier only predicts mainly a single class for the first output. Now with `class_weight="balanced"`: + +```python +clf = CustomClassifier(get_model, class_weight="balanced", epochs=100, verbose=0, random_state=0) +clf.fit(X, y) +y_pred = clf.predict(X) +(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True) +print(counts_bin) +(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True) +print(counts_cat) +``` + +Now, we get (mostly) balanced classes. But what if we want to specify our classes manually? You will notice that in when we defined `DatasetTransformer`, we gave it the ability to handle +a list of class weights. For demonstration purposes, we will highly bias towards the second class in each output: + +```python +clf = CustomClassifier(get_model, class_weight=[{0: 0.1, 1: 1}, {0: 0.1, 1: 1, 2: 0.1}], epochs=100, verbose=0, random_state=0) +clf.fit(X, y) +y_pred = clf.predict(X) +(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True) +print(counts_bin) +(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True) +print(counts_cat) +``` + +Or mixing the two methods, because our first output is unbalanced but our second is (presumably) balanced: + +```python +clf = CustomClassifier(get_model, class_weight=["balanced", None], epochs=100, verbose=0, random_state=0) +clf.fit(X, y) +y_pred = clf.predict(X) +(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True) +print(counts_bin) +(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True) +print(counts_cat) +``` + +## 7. Custom validation dataset + +Although `dataset_transformer` is primarily designed for data transformations, because it returns valid `**kwargs` to fit it can be used for other advanced use cases. +In this example, we use `dataset_transformer` to implement a custom test/train split for Keras' internal validation. We'll use sklearn's +`train_test_split`, but this could be implemented via an arbitrary user function, eg. to ensure balanced class distribution. + +```python +from sklearn.model_selection import train_test_split + + +def get_clf(meta: Dict[str, Any]): + inp = keras.layers.Input(shape=(meta["n_features_in_"],)) + x1 = keras.layers.Dense(100, activation="relu")(inp) + out = keras.layers.Dense(1, activation="sigmoid")(x1) + return keras.Model(inputs=inp, outputs=out) + + +class CustomSplit(BaseEstimator, TransformerMixin): + + def __init__(self, test_size: float): + self.test_size = test_size + + def fit(self, data: Dict[str, Any]) -> "CustomSplit": + return self + + def transform(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.test_size == 0: + return data + x, y, sw = data["x"], data.get("y", None), data.get("sample_weight", None) + if y is None: + return data + if sw is None: + x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=self.test_size, stratify=y) + validation_data = (x_val, y_val) + sw_train = None + else: + x_train, x_val, y_train, y_val, sw_train, sw_val = train_test_split(x, y, sw, test_size=self.test_size, stratify=y) + validation_data = (x_val, y_val, sw_val) + data["validation_data"] = validation_data + data["x"], data["y"], data["sample_weight"] = x_train, y_train, sw_train + return data + + +class CustomClassifier(KerasClassifier): + + @property + def dataset_transformer(self): + return CustomSplit(test_size=self.validation_split) +``` + +And now lets test with a toy dataset. We specifically choose to make the target strings to show +that with this approach, we can preserve all of the nice data pre-processing that SciKeras does +for us, while still being able to split the final data before passing it to Keras. + +```python +y = np.array(["a"] * 900 + ["b"] * 100) +X = np.array([0] * 900 + [1] * 100).reshape(-1, 1) +``` + +To get a base measurment to compare against, we'll run first with KerasClassifier as a benchmark. + +```python +clf = KerasClassifier( + get_clf, + loss="bce", + metrics=["binary_accuracy"], + verbose=False, + validation_split=0.1, + shuffle=False, + random_state=0, + epochs=10 +) + +clf.fit(X, y) +print(f"binary_accuracy = {clf.history_['binary_accuracy'][-1]}") +print(f"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}") +``` + +We see that we get near zero validation accuracy. Because one of our classes was only found in the tail end of our dataset and we specified `validation_split=0.1`, we validated with a class we had never seen before. + +We could specify `shuffle=True` (this is actually the default), but for highly imbalanced classes, this may not be as good as stratified splitting. + +So lets test our new `CustomClassifier`. + +```python +clf = CustomClassifier( + get_clf, + loss="bce", + metrics=["binary_accuracy"], + verbose=False, + validation_split=0.1, + shuffle=False, + random_state=0, + epochs=10 +) + +clf.fit(X, y) +print(f"binary_accuracy = {clf.history_['binary_accuracy'][-1]}") +print(f"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}") +``` + +Much better! + + +## 8. Dynamically setting batch_size + + +In this tutorial, we use the `data_transformer` interface to implement a dynamic batch_size, similar to sklearn's [MLPClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html). We will implement `batch_size` as `batch_size=min(200, n_samples)`. + +```python +from sklearn.model_selection import train_test_split + + +def check_batch_size(x): + """Check the batch_size used in training. + """ + bs = x.shape[0] + if bs is not None: + print(f"batch_size={bs}") + return x + + +def get_clf(meta: Dict[str, Any]): + inp = keras.layers.Input(shape=(meta["n_features_in_"],)) + x1 = keras.layers.Dense(100, activation="relu")(inp) + x2 = keras.layers.Lambda(check_batch_size)(x1) + out = keras.layers.Dense(1, activation="sigmoid")(x2) + return keras.Model(inputs=inp, outputs=out) + + +class DynamicBatch(BaseEstimator, TransformerMixin): + + def fit(self, data: Dict[str, Any]) -> "DynamicBatch": + return self + + def transform(self, data: Dict[str, Any]) -> Dict[str, Any]: + n_samples = data["x"].shape[0] + data["batch_size"] = min(200, n_samples) + return data + + +class DynamicBatchClassifier(KerasClassifier): + + @property + def dataset_transformer(self): + return DynamicBatch() +``` + +Since this is happening inside SciKeras, this will work even if we are doing cross validation (which adjusts the split according to `cv`). + +```python +from sklearn.model_selection import cross_val_score + +clf = DynamicBatchClassifier( + get_clf, + loss="bce", + verbose=False, + random_state=0 +) + +_ = cross_val_score(clf, X, y, cv=6) # note: 1000 / 6 = 167 +``` + +But if we train with larger inputs, we can hit the cap of 200 we set: + +```python +_ = cross_val_score(clf, X, y, cv=5) +``` diff --git a/scikeras/utils/transformers.py b/scikeras/utils/transformers.py index 8bb87da7d..31fd29a18 100644 --- a/scikeras/utils/transformers.py +++ b/scikeras/utils/transformers.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -7,6 +7,7 @@ from sklearn.exceptions import NotFittedError from sklearn.pipeline import make_pipeline from sklearn.preprocessing import FunctionTransformer, OneHotEncoder, OrdinalEncoder +from sklearn.utils.class_weight import compute_sample_weight from sklearn.utils.multiclass import type_of_target from tensorflow.keras.losses import Loss from tensorflow.python.keras.losses import is_categorical_crossentropy @@ -27,6 +28,11 @@ class TargetReshaper(BaseEstimator, TransformerMixin): def fit(self, y: np.ndarray) -> "TargetReshaper": """Fit the transformer to a target y. + Parameters + ---------- + y : np.ndarray + The target data to be transformed. + Returns ------- TargetReshaper @@ -316,6 +322,11 @@ def fit(self, y: np.ndarray) -> "RegressorTargetEncoder": For RegressorTargetEncoder, this just records the dimensions of y as the expected number of outputs and saves the dtype. + Parameters + ---------- + y : np.ndarray + The target data to be transformed. + Returns ------- RegressorTargetEncoder @@ -383,3 +394,29 @@ def get_metadata(self): "n_outputs_": self.n_outputs_, "n_outputs_expected_": self.n_outputs_expected_, } + + +class ClassWeightDataTransformer(BaseEstimator, TransformerMixin): + """Default dataset_transformer for KerasClassifier. + + This transformer implements handling of the `class_weight` parameter + for single output classifiers. + """ + + def __init__(self, class_weight: Optional[Union[str, Dict[int, float]]] = None): + self.class_weight = class_weight + + def fit( + self, data: Dict[str, Any], dummy: None = None + ) -> "ClassWeightDataTransformer": + return self + + def transform(self, data: Dict[str, Any]) -> Dict[str, Any]: + y, sample_weight = data.get("y", None), data.get("sample_weight", None) + if self.class_weight is None or y is None: + return data + sample_weight = 1 if sample_weight is None else sample_weight + sample_weight *= compute_sample_weight(class_weight=self.class_weight, y=y) + data["sample_weight"] = sample_weight + data["class_weight"] = None + return data diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index b924c2156..5e0062132 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -15,7 +15,6 @@ from sklearn.metrics import accuracy_score as sklearn_accuracy_score from sklearn.metrics import r2_score as sklearn_r2_score from sklearn.preprocessing import FunctionTransformer -from sklearn.utils.class_weight import compute_sample_weight from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import _check_sample_weight, check_array, check_X_y from tensorflow.keras import losses as losses_module @@ -33,7 +32,11 @@ unflatten_params, ) from scikeras.utils import loss_name, metric_name -from scikeras.utils.transformers import ClassifierLabelEncoder, RegressorTargetEncoder +from scikeras.utils.transformers import ( + ClassifierLabelEncoder, + ClassWeightDataTransformer, + RegressorTargetEncoder, +) class BaseWrapper(BaseEstimator): @@ -135,8 +138,8 @@ class BaseWrapper(BaseEstimator): "callbacks", "validation_split", "shuffle", - "class_weight", "sample_weight", + "class_weight", "initial_epoch", "validation_steps", "validation_batch_size", @@ -475,16 +478,21 @@ def _fit_keras_model( # collect parameters params = self.get_params() fit_args = route_params(params, destination="fit", pass_filter=self._fit_kwargs) - fit_args["sample_weight"] = sample_weight fit_args["epochs"] = initial_epoch + epochs fit_args["initial_epoch"] = initial_epoch fit_args.update(kwargs) + fit_args["x"] = X + fit_args["y"] = y + fit_args["sample_weight"] = sample_weight + + fit_args = self.dataset_transformer_.transform(fit_args) + if self._random_state is not None: with TFRandomState(self._random_state): - hist = self.model_.fit(x=X, y=y, **fit_args) + hist = self.model_.fit(**fit_args) else: - hist = self.model_.fit(x=X, y=y, **fit_args) + hist = self.model_.fit(**fit_args) if not warm_start or not hasattr(self, "history_") or initial_epoch == 0: self.history_ = defaultdict(list) @@ -535,7 +543,12 @@ def _check_model_compatibility(self, y: np.ndarray) -> None: ) def _validate_data( - self, X=None, y=None, reset: bool = False, y_numeric: bool = False + self, + X=None, + y=None, + sample_weight=None, + reset: bool = False, + y_numeric: bool = False, ) -> Tuple[np.ndarray, Union[np.ndarray, None]]: """Validate input arrays and set or check their meta-parameters. @@ -642,7 +655,9 @@ def _check_array_dtype(arr, force_numeric): n_features_in_, self.__class__.__name__, self.n_features_in_ ) ) - return X, y + if sample_weight is not None: + X, sample_weight = self._validate_sample_weight(X, sample_weight) + return X, y, sample_weight def _type_of_target(self, y: np.ndarray) -> str: return type_of_target(y) @@ -681,6 +696,36 @@ def feature_encoder(self): """ return FunctionTransformer() + @property + def dataset_transformer(self): + """Retrieve a transformer to be applied jointly to the entire + dataset (X, y & sample_weights). + + You can override this method to provide custom transformations. + + It MUST accept a 3 element tuple as it's single input argument + to `fit` and `transform`. `transform` must also output + a 3 element tuple in the same format. + The first element corresponds to X, or as an output from the + transformer, to a `tf.data.Dataset` instance containing + X, y and optionally sample_weights. + The second element corresponds to `y`, and may be None + on the output side always and on the input side when + called from `predict`. + The third element is `sample_weights` which may be None + on the input and output sides. + + Note that `inverse_transform` is never used + and is not required to be implemented. + + Returns + ------- + dataset_transformer + Transformer implementing the sklearn transformer + interface. + """ + return FunctionTransformer() + def fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper": """Constructs a new model with ``model`` & fit the model to ``(X, y)``. @@ -737,7 +782,10 @@ def initialized_(self) -> bool: return hasattr(self, "model_") def _initialize( - self, X: np.ndarray, y: Union[np.ndarray, None] = None + self, + X: np.ndarray, + y: Union[np.ndarray, None] = None, + sample_weight: Union[np.ndarray, None] = None, ) -> Tuple[np.ndarray, np.ndarray]: # Handle random state @@ -754,20 +802,24 @@ def _initialize( # int or None self._random_state = self.random_state - X, y = self._validate_data(X, y, reset=True) + X, y, sample_weight = self._validate_data(X, y, sample_weight, reset=True) self.target_encoder_ = self.target_encoder.fit(y) target_metadata = getattr(self.target_encoder_, "get_metadata", dict)() vars(self).update(**target_metadata) self.feature_encoder_ = self.feature_encoder.fit(X) - feature_meta = getattr(self.feature_encoder, "get_metadata", dict)() + feature_meta = getattr(self.feature_encoder_, "get_metadata", dict)() vars(self).update(**feature_meta) self.model_ = self._build_keras_model() - return X, y + self.dataset_transformer_ = self.dataset_transformer.fit( + dict(x=X, y=y, sample_weight=sample_weight) + ) + + return X, y, sample_weight - def initialize(self, X, y=None) -> "BaseWrapper": + def initialize(self, X, y=None, sample_weight=None) -> "BaseWrapper": """Initialize the model without any fitting. You only need to call this model if you explicitly do not want to do any fitting @@ -788,7 +840,7 @@ def initialize(self, X, y=None) -> "BaseWrapper": BaseWrapper A reference to the BaseWrapper instance for chained calling. """ - self._initialize(X, y) + self._initialize(X, y, sample_weight) return self # to allow chained calls like initialize(...).predict(...) def _fit( @@ -825,17 +877,15 @@ def _fit( """ # Data checks if not ((self.warm_start or warm_start) and self.initialized_): - X, y = self._initialize(X, y) + X, y, sample_weight = self._initialize(X, y, sample_weight) else: - X, y = self._validate_data(X, y) + X, y, sample_weight = self._validate_data(X, y, sample_weight) - if sample_weight is not None: - X, sample_weight = self._validate_sample_weight(X, sample_weight) - - y = self.target_encoder_.transform(y) X = self.feature_encoder_.transform(X) - self._check_model_compatibility(y) + if y is not None: + y = self.target_encoder_.transform(y) + self._check_model_compatibility(y) self._fit_keras_model( X, @@ -897,7 +947,7 @@ def _predict_raw(self, X, **kwargs): "Estimator needs to be fit before `predict` " "can be called" ) # basic input checks - X, _ = self._validate_data(X=X, y=None) + X, _, _ = self._validate_data(X=X) # pre process X X = self.feature_encoder_.transform(X) @@ -908,9 +958,12 @@ def _predict_raw(self, X, **kwargs): params, destination="predict", pass_filter=self._predict_kwargs ) pred_args.update(kwargs) + pred_args["x"] = X + + pred_args = self.dataset_transformer_.transform(pred_args) # predict with Keras model - y_pred = self.model_.predict(X, **pred_args) + y_pred = self.model_.predict(**pred_args) return y_pred @@ -994,7 +1047,7 @@ def score(self, X, y, sample_weight=None) -> float: ) # validate y - _, y = self._validate_data(X=None, y=y) + _, y, _ = self._validate_data(X=None, y=y) # compute Keras model score y_pred = self.predict(X) @@ -1311,6 +1364,41 @@ def target_encoder(self): categories = "auto" if self.classes_ is None else [self.classes_] return ClassifierLabelEncoder(loss=self.loss, categories=categories) + @property + def dataset_transformer(self): + """Retrieve a transformer to be applied jointly to the entire + dataset (X, y & sample_weights). + + By default, KerasClassifier implements ClassWeightDataTransformer, + which embeds class_weight into sample_weight. + + You can override this method to provide custom transformations. + To keep the default class_weight behavior, you can chain your + transfromer and ClassWeightDataTransformer using a Pipeline. + + It MUST accept a 3 element tuple as it's single input argument + to `fit` and `transform`. `transform` must also output + a 3 element tuple in the same format. + The first element corresponds to X, or as an output from the + transformer, to a `tf.data.Dataset` instance containing + X, y and optionally sample_weights. + The second element corresponds to `y`, and may be None + on the output side always and on the input side when + called from `predict`. + The third element is `sample_weights` which may be None + on the input and output sides. + + Note that `inverse_transform` is never used + and is not required to be implemented. + + Returns + ------- + dataset_transformer + Transformer implementing the sklearn transformer + interface. + """ + return ClassWeightDataTransformer(class_weight=self.class_weight) + def initialize(self, X, y) -> "KerasClassifier": """Initialize the model without any fitting. You only need to call this model if you explicitly do not want to do any fitting @@ -1361,9 +1449,6 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "KerasClassifier": (ex: instance.fit(X,y).transform(X) ) """ self.classes_ = None - if self.class_weight is not None: - sample_weight = 1 if sample_weight is None else sample_weight - sample_weight *= compute_sample_weight(class_weight=self.class_weight, y=y) super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs) return self @@ -1398,9 +1483,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None) -> "KerasClassifie self.classes_ = ( classes if classes is not None else getattr(self, "classes_", None) ) - if self.class_weight is not None: - sample_weight = 1 if sample_weight is None else sample_weight - sample_weight *= compute_sample_weight(class_weight=self.class_weight, y=y) super().partial_fit(X, y, sample_weight=sample_weight) return self @@ -1581,13 +1663,20 @@ def scorer(y_true, y_pred, **kwargs) -> float: return sklearn_r2_score(y_true, y_pred, **kwargs) def _validate_data( - self, X=None, y=None, reset: bool = False, y_numeric: bool = False + self, + X=None, + y=None, + sample_weight=None, + reset: bool = False, + y_numeric: bool = False, ) -> Tuple[np.ndarray, Union[np.ndarray, None]]: # For regressors, y should ALWAYS be numeric # To enforce this without additional dtype checks, we set `y_numeric=True` # when calling `_validate_data` which will force casting to numeric for # non-numeric data. - return super()._validate_data(X=X, y=y, reset=reset, y_numeric=True) + return super()._validate_data( + X=X, y=y, sample_weight=sample_weight, reset=reset, y_numeric=True + ) @property def target_encoder(self): diff --git a/tests/test_api.py b/tests/test_api.py index 0cf2a4d4e..9b9b0006c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,10 +1,12 @@ """Tests for Scikit-learn API wrapper.""" import pickle -from typing import Any, Dict +from typing import Any, Dict, Tuple +from unittest.mock import patch import numpy as np import pytest +import tensorflow as tf from sklearn.calibration import CalibratedClassifierCV from sklearn.datasets import load_boston, load_digits, load_iris @@ -16,8 +18,8 @@ ) from sklearn.exceptions import NotFittedError from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import StandardScaler +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.preprocessing import FunctionTransformer, StandardScaler from tensorflow.keras import losses as losses_module from tensorflow.keras import metrics as metrics_module from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input @@ -27,7 +29,7 @@ from tensorflow.python.keras.utils.generic_utils import register_keras_serializable from tensorflow.python.keras.utils.np_utils import to_categorical -from scikeras.wrappers import BaseWrapper, KerasClassifier, KerasRegressor +from scikeras.wrappers import KerasClassifier, KerasRegressor from .mlp_models import dynamic_classifier, dynamic_regressor from .testing_utils import basic_checks @@ -801,3 +803,81 @@ def test_prebuilt_model(self, wrapper): np.testing.assert_allclose(y_pred_keras, y_pred_scikeras) # Check that we are still using the same model object assert est.model_ is m2 + + +class TestDatasetTransformer: + def test_conversion_to_dataset(self): + """Check that the dataset_transformer + interface can return a tf Dataset + """ + inp = Input((1,)) + out = Dense(1, activation="sigmoid")(inp) + m = Model(inp, out) + m.compile(loss="bce") + + def transform(fit_kwargs: Dict[str, Any]): + x = fit_kwargs.pop("x") + y = fit_kwargs.pop("y") if "y" in fit_kwargs else None + sample_weight = ( + fit_kwargs.pop("sample_weight") + if "sample_weight" in fit_kwargs + else None + ) + fit_kwargs["x"] = tf.data.Dataset.from_tensor_slices((x, y, sample_weight)) + return fit_kwargs + + class MyWrapper(KerasClassifier): + @property + def dataset_transformer(self): + return FunctionTransformer(transform) + + est = MyWrapper(m) + X = np.random.random((100, 1)) + y = np.array(["a", "b"] * 50, dtype=str) + fit_orig = m.fit + + def check_fit(**kwargs): + assert isinstance(kwargs["x"], tf.data.Dataset) + assert "y" not in kwargs + assert "sample_weight" not in kwargs + return fit_orig(**kwargs) + + with patch.object(m, "fit", new=check_fit): + est.fit(X, y) + y_pred = est.predict(X) + assert y_pred.dtype == y.dtype + assert y_pred.shape == y.shape + assert set(y_pred).issubset(set(y)) + + def test_pipeline(self): + """Check that the dataset_transformer + interface is compatible with Pipelines + """ + inp = Input((1,)) + out = Dense(1, activation="sigmoid")(inp) + m = Model(inp, out) + m.compile(loss="bce") + + def transform(fit_kwargs: Dict[str, Any]): + x = fit_kwargs.pop("x") + y = fit_kwargs.pop("y") if "y" in fit_kwargs else None + sample_weight = ( + fit_kwargs.pop("sample_weight") + if "sample_weight" in fit_kwargs + else None + ) + fit_kwargs["x"] = tf.data.Dataset.from_tensor_slices((x, y, sample_weight)) + return fit_kwargs + + class MyWrapper(KerasClassifier): + @property + def dataset_transformer(self): + t1 = super().dataset_transformer + t2 = FunctionTransformer(transform) + return make_pipeline(t1, t2) + + est = MyWrapper(m, class_weight="balanced") + X = np.random.random((100, 1)) + y = np.array(["a", "b"] * 50, dtype=str) + + est.fit(X, y)