diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py
index 4be457d98..b36a55666 100644
--- a/scikeras/wrappers.py
+++ b/scikeras/wrappers.py
@@ -80,6 +80,11 @@ class BaseWrapper(BaseEstimator):
         If False, subsequent fit calls will reset the entire model.
         This has no impact on partial_fit, which always trains
         for a single epoch starting from the current epoch.
+    batch_size : Union[int, None], default None
+        Number of samples per gradient update.
+        This will be applied to both `fit` and `predict`. To specify different numbers,
+        pass `fit__batch_size=32` and `predict__batch_size=1000` (for example).
+        To auto-adjust the batch size to use all samples, pass `batch_size=-1`.

     Attributes
     ----------
@@ -212,6 +217,7 @@ def __init__(
             None,
         ] = None,
         batch_size: Union[int, None] = None,
+        validation_batch_size: Union[int, None] = None,
         verbose: int = 1,
         callbacks: Union[
             List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -233,6 +239,7 @@ def __init__(
         self.loss = loss
         self.metrics = metrics
         self.batch_size = batch_size
+        self.validation_batch_size = validation_batch_size
         self.verbose = verbose
         self.callbacks = callbacks
         self.validation_split = validation_split
@@ -487,6 +494,15 @@ def _fit_keras_model(
         fit_args["epochs"] = initial_epoch + epochs
         fit_args["initial_epoch"] = initial_epoch
         fit_args.update(kwargs)
+        for bs_kwarg in ("batch_size", "validation_batch_size"):
+            if bs_kwarg in fit_args:
+                if fit_args[bs_kwarg] == -1:
+                    try:
+                        fit_args[bs_kwarg] = X.shape[0]
+                    except AttributeError:
+                        raise ValueError(
+                            f"`{bs_kwarg}=-1` requires that `X` implement `shape`"
+                        )

         if self._random_state is not None:
             with TFRandomState(self._random_state):
@@ -918,9 +934,17 @@ def _predict_raw(self, X, **kwargs):
             params, destination="predict", pass_filter=self._predict_kwargs
         )
         pred_args.update(kwargs)
+        if "batch_size" in pred_args:
+            if pred_args["batch_size"] == -1:
+                try:
+                    pred_args["batch_size"] = X.shape[0]
+                except AttributeError:
+                    raise ValueError(
+                        "`batch_size=-1` requires that `X` implement `shape`"
+                    )

         # predict with Keras model
-        y_pred = self.model_.predict(X, **pred_args)
+        y_pred = self.model_.predict(x=X, **pred_args)

         return y_pred

@@ -1135,6 +1159,11 @@ class KerasClassifier(BaseWrapper):
         If False, subsequent fit calls will reset the entire model.
         This has no impact on partial_fit, which always trains
         for a single epoch starting from the current epoch.
+    batch_size : Union[int, None], default None
+        Number of samples per gradient update.
+        This will be applied to both `fit` and `predict`. To specify different numbers,
+        pass `fit__batch_size=32` and `predict__batch_size=1000` (for example).
+        To auto-adjust the batch size to use all samples, pass `batch_size=-1`.
     class_weight : Union[Dict[Any, float], str, None], default None
         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one.
@@ -1241,6 +1270,7 @@ def __init__(
             None,
         ] = None,
         batch_size: Union[int, None] = None,
+        validation_batch_size: Union[int, None] = None,
         verbose: int = 1,
         callbacks: Union[
             List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -1262,6 +1292,7 @@ def __init__(
             loss=loss,
             metrics=metrics,
             batch_size=batch_size,
+            validation_batch_size=validation_batch_size,
             verbose=verbose,
             callbacks=callbacks,
             validation_split=validation_split,
@@ -1473,32 +1504,33 @@ class KerasRegressor(BaseWrapper):
         must return a compiled instance of a Keras Model
         to be used by `fit`, `predict`, etc.
         If None, you must implement ``_keras_build_fn``.
-
     optimizer : Union[str, tf.keras.optimizers.Optimizer, Type[tf.keras.optimizers.Optimizer]], default "rmsprop"
         This can be a string for Keras' built in optimizers,
         an instance of tf.keras.optimizers.Optimizer
         or a class inheriting from tf.keras.optimizers.Optimizer.
         Only strings and classes support parameter routing.
-
     loss : Union[Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None], default None
         The loss function to use for training.
         This can be a string for Keras' built in losses,
         an instance of tf.keras.losses.Loss
         or a class inheriting from tf.keras.losses.Loss .
         Only strings and classes support parameter routing.
-
     random_state : Union[int, np.random.RandomState, None], default None
         Set the Tensorflow random number generators to a
         reproducible deterministic state using this seed.
         Pass an int for reproducible results across multiple
         function calls.
-
     warm_start : bool, default False
         If True, subsequent calls to fit will _not_ reset
         the model parameters but *will* reset the epoch to zero.
         If False, subsequent fit calls will reset the entire model.
         This has no impact on partial_fit, which always trains
         for a single epoch starting from the current epoch.
+    batch_size : Union[int, None], default None
+        Number of samples per gradient update.
+        This will be applied to both `fit` and `predict`. To specify different numbers,
+        pass `fit__batch_size=32` and `predict__batch_size=1000` (for example).
+        To auto-adjust the batch size to use all samples, pass `batch_size=-1`.

     Attributes
     ----------
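The same sentinel-resolution pattern appears twice in the patch above, once for `fit` (covering both `batch_size` and `validation_batch_size`) and once for `predict`: a value of -1 is swapped for the sample count before the arguments reach Keras, and inputs that lack a `shape` attribute fail with an explicit ValueError rather than an opaque Keras error. A minimal standalone sketch of that pattern (`resolve_batch_size` is a hypothetical helper used here for illustration, not part of scikeras):

    import numpy as np

    def resolve_batch_size(batch_size, X):
        # -1 means "use every sample in a single batch"; anything else passes through.
        if batch_size == -1:
            try:
                return X.shape[0]
            except AttributeError:
                # Lists and other shapeless containers cannot be auto-sized.
                raise ValueError("`batch_size=-1` requires that `X` implement `shape`")
        return batch_size

    assert resolve_batch_size(-1, np.zeros((100, 1))) == 100  # full-batch
    assert resolve_batch_size(32, np.zeros((100, 1))) == 32   # unchanged

The tests below exercise exactly these two paths: the happy path via numpy arrays, and the error path via a `feature_encoder` that wraps `X` in a plain list.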
diff --git a/tests/test_parameters.py b/tests/test_parameters.py
index 2a0670a46..25b8b53de 100644
--- a/tests/test_parameters.py
+++ b/tests/test_parameters.py
@@ -8,6 +8,7 @@
 from sklearn.base import clone
 from sklearn.datasets import make_blobs, make_classification
 from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import FunctionTransformer

 from scikeras.wrappers import KerasClassifier, KerasRegressor

@@ -320,3 +321,84 @@ def test_kwargs(wrapper, builder):
             or hasattr(est, "fit__" + k)
             or hasattr(est, "predict__" + k)
         )
+
+
+@pytest.mark.parametrize("length", (10, 100))
+@pytest.mark.parametrize("prefix", ("", "fit__"))
+@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
+def test_batch_size_all_fit(length, prefix, base):
+
+    kw = prefix + base
+
+    y = np.random.random((length,))
+    X = y.reshape((-1, 1))
+    est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})
+
+    est.initialize(X, y)
+
+    fit_orig = est.model_.fit
+
+    def check_batch_size(**kwargs):
+        assert kwargs[base] == X.shape[0]
+        return fit_orig(**kwargs)
+
+    with mock.patch.object(est.model_, "fit", new=check_batch_size):
+        est.fit(X, y)
+
+
+@pytest.mark.parametrize("length", (10, 100))
+@pytest.mark.parametrize("prefix", ("", "predict__"))
+@pytest.mark.parametrize("base", ("batch_size",))
+def test_batch_size_all_predict(length, prefix, base):
+
+    kw = prefix + base
+
+    y = np.random.random((length,))
+    X = y.reshape((-1, 1))
+    est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})
+
+    est.fit(X, y)
+
+    pred_orig = est.model_.predict
+
+    def check_batch_size(**kwargs):
+        assert kwargs[base] == X.shape[0]
+        return pred_orig(**kwargs)
+
+    with mock.patch.object(est.model_, "predict", new=check_batch_size):
+        est.predict(X)
("validation_batch_size", "batch_size")) +def test_batch_size_all_fit(length, prefix, base): + + kw = prefix + base + + y = np.random.random((length,)) + X = y.reshape((-1, 1)) + est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1}) + + est.initialize(X, y) + + fit_orig = est.model_.fit + + def check_batch_size(**kwargs): + assert kwargs[base] == X.shape[0] + return fit_orig(**kwargs) + + with mock.patch.object(est.model_, "fit", new=check_batch_size): + est.fit(X, y) + + +@pytest.mark.parametrize("prefix", ("", "fit__")) +@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size")) +def test_batch_size_all_fit_non_array(prefix, base): + + kw = prefix + base + + class CustomReg(KerasRegressor): + @property + def feature_encoder(self): + return FunctionTransformer(lambda x: [x]) + + y = np.random.random((100,)) + X = y.reshape((-1, 1)) + est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1}) + + with pytest.raises(ValueError, match="requires that `X` implement `shape`"): + est.fit(X, y) + + +def test_batch_size_all_predict_non_array(): + class CustomReg(KerasRegressor): + @property + def feature_encoder(self): + return FunctionTransformer(lambda x: [x]) + + y = np.random.random((100,)) + X = y.reshape((-1, 1)) + est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], predict__batch_size=-1) + + est.fit(X, y) + + with pytest.raises(ValueError, match="requires that `X` implement `shape`"): + est.predict(X)