Implement batch_size=-1 (#194)
adriangb authored Feb 26, 2021
1 parent d83bffa commit d1a9d12
Showing 2 changed files with 142 additions and 5 deletions.
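For context, a minimal usage sketch of the behavior this commit adds (the `get_model` builder and the random data are illustrative, not part of the commit): passing `batch_size=-1` makes SciKeras use all samples in a single batch per gradient update.

import numpy as np
from tensorflow import keras
from scikeras.wrappers import KerasRegressor

def get_model():
    # Hypothetical single-layer regression model.
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss="mse", optimizer="rmsprop")
    return model

X = np.random.random((100, 1))
y = np.random.random((100,))

# batch_size=-1 is resolved to X.shape[0] == 100: full-batch training.
est = KerasRegressor(get_model, batch_size=-1, epochs=5, verbose=0)
est.fit(X, y)
y_pred = est.predict(X)  # predict likewise runs as one batch of 100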
42 changes: 37 additions & 5 deletions scikeras/wrappers.py
@@ -80,6 +80,11 @@ class BaseWrapper(BaseEstimator):
        If False, subsequent fit calls will reset the entire model.
        This has no impact on partial_fit, which always trains
        for a single epoch starting from the current epoch.
    batch_size : Union[int, None], default None
        Number of samples per gradient update.
        This applies to both `fit` and `predict`. To use different values,
        pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
        To use the entire dataset as a single batch, pass `batch_size=-1`.
    validation_batch_size : Union[int, None], default None
        Number of samples per validation batch.
        Like `batch_size`, `-1` resolves to the number of samples.
    Attributes
    ----------
@@ -212,6 +217,7 @@ def __init__(
            None,
        ] = None,
        batch_size: Union[int, None] = None,
        validation_batch_size: Union[int, None] = None,
        verbose: int = 1,
        callbacks: Union[
            List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -233,6 +239,7 @@ def __init__(
        self.loss = loss
        self.metrics = metrics
        self.batch_size = batch_size
        self.validation_batch_size = validation_batch_size
        self.verbose = verbose
        self.callbacks = callbacks
        self.validation_split = validation_split
@@ -487,6 +494,15 @@ def _fit_keras_model(
        fit_args["epochs"] = initial_epoch + epochs
        fit_args["initial_epoch"] = initial_epoch
        fit_args.update(kwargs)
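        # Resolve batch_size=-1 (and validation_batch_size=-1) to the full sample count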
        for bs_kwarg in ("batch_size", "validation_batch_size"):
            if bs_kwarg in fit_args:
                if fit_args[bs_kwarg] == -1:
                    try:
                        fit_args[bs_kwarg] = X.shape[0]
                    except AttributeError:
                        raise ValueError(
                            f"`{bs_kwarg}=-1` requires that `X` implement `shape`"
                        )

        if self._random_state is not None:
            with TFRandomState(self._random_state):
@@ -918,9 +934,17 @@ def _predict_raw(self, X, **kwargs):
            params, destination="predict", pass_filter=self._predict_kwargs
        )
        pred_args.update(kwargs)
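        # As in fit, resolve batch_size=-1 to the number of samples in X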
if "batch_size" in pred_args:
if pred_args["batch_size"] == -1:
try:
pred_args["batch_size"] = X.shape[0]
except AttributeError:
raise ValueError(
"`batch_size=-1` requires that `X` implement `shape`"
)

        # predict with Keras model
-        y_pred = self.model_.predict(X, **pred_args)
+        y_pred = self.model_.predict(x=X, **pred_args)

        return y_pred

@@ -1135,6 +1159,11 @@ class KerasClassifier(BaseWrapper):
        If False, subsequent fit calls will reset the entire model.
        This has no impact on partial_fit, which always trains
        for a single epoch starting from the current epoch.
    batch_size : Union[int, None], default None
        Number of samples per gradient update.
        This applies to both `fit` and `predict`. To use different values,
        pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
        To use the entire dataset as a single batch, pass `batch_size=-1`.
    class_weight : Union[Dict[Any, float], str, None], default None
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are assumed to have weight one.
@@ -1241,6 +1270,7 @@ def __init__(
            None,
        ] = None,
        batch_size: Union[int, None] = None,
        validation_batch_size: Union[int, None] = None,
        verbose: int = 1,
        callbacks: Union[
            List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -1262,6 +1292,7 @@ def __init__(
            loss=loss,
            metrics=metrics,
            batch_size=batch_size,
            validation_batch_size=validation_batch_size,
            verbose=verbose,
            callbacks=callbacks,
            validation_split=validation_split,
@@ -1473,32 +1504,33 @@ class KerasRegressor(BaseWrapper):
        must return a compiled instance of a Keras Model
        to be used by `fit`, `predict`, etc.
        If None, you must implement ``_keras_build_fn``.
    optimizer : Union[str, tf.keras.optimizers.Optimizer, Type[tf.keras.optimizers.Optimizer]], default "rmsprop"
        This can be a string for Keras' built-in optimizers,
        an instance of tf.keras.optimizers.Optimizer,
        or a class inheriting from tf.keras.optimizers.Optimizer.
        Only strings and classes support parameter routing.
    loss : Union[Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None], default None
        The loss function to use for training.
        This can be a string for Keras' built-in losses,
        an instance of tf.keras.losses.Loss,
        or a class inheriting from tf.keras.losses.Loss.
        Only strings and classes support parameter routing.
    random_state : Union[int, np.random.RandomState, None], default None
        Set the TensorFlow random number generators to a
        reproducible deterministic state using this seed.
        Pass an int for reproducible results across multiple
        function calls.
    warm_start : bool, default False
        If True, subsequent calls to fit will _not_ reset
        the model parameters but *will* reset the epoch to zero.
        If False, subsequent fit calls will reset the entire model.
        This has no impact on partial_fit, which always trains
        for a single epoch starting from the current epoch.
    batch_size : Union[int, None], default None
        Number of samples per gradient update.
        This applies to both `fit` and `predict`. To use different values,
        pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
        To use the entire dataset as a single batch, pass `batch_size=-1`.
    Attributes
    ----------
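As the docstrings above note, the `fit__` and `predict__` prefixes route different batch sizes to training and prediction; a sketch reusing the hypothetical `get_model` builder and data from the first example:

# Full-batch gradient updates during training, bounded memory at prediction time.
est = KerasRegressor(
    get_model,
    fit__batch_size=-1,        # resolved to X.shape[0] when fit is called
    predict__batch_size=1000,  # fixed batch size for predict
    epochs=5,
    verbose=0,
)
est.fit(X, y)
y_pred = est.predict(X)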
105 changes: 105 additions & 0 deletions tests/test_parameters.py
@@ -8,6 +8,7 @@
from sklearn.base import clone
from sklearn.datasets import make_blobs, make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import FunctionTransformer

from scikeras.wrappers import KerasClassifier, KerasRegressor

@@ -320,3 +321,107 @@ def test_kwargs(wrapper, builder):
        or hasattr(est, "fit__" + k)
        or hasattr(est, "predict__" + k)
    )


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
def test_batch_size_all_fit(length, prefix, base):
    kw = prefix + base

    y = np.random.random((length,))
    X = y.reshape((-1, 1))
    est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

    est.initialize(X, y)

    fit_orig = est.model_.fit

    def check_batch_size(**kwargs):
        assert kwargs[base] == X.shape[0]
        return fit_orig(**kwargs)

    with mock.patch.object(est.model_, "fit", new=check_batch_size):
        est.fit(X, y)


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "predict__"))
@pytest.mark.parametrize("base", ("batch_size",))
def test_batch_size_all_predict(length, prefix, base):
    kw = prefix + base

    y = np.random.random((length,))
    X = y.reshape((-1, 1))
    est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

    est.fit(X, y)

    pred_orig = est.model_.predict

    def check_batch_size(**kwargs):
        assert kwargs[base] == X.shape[0]
        return pred_orig(**kwargs)

    with mock.patch.object(est.model_, "predict", new=check_batch_size):
        est.predict(X)


@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
def test_batch_size_all_fit_non_array(prefix, base):
    kw = prefix + base

    class CustomReg(KerasRegressor):
        @property
        def feature_encoder(self):
            return FunctionTransformer(lambda x: [x])

    y = np.random.random((100,))
    X = y.reshape((-1, 1))
    est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

    with pytest.raises(ValueError, match="requires that `X` implement `shape`"):
        est.fit(X, y)


def test_batch_size_all_predict_non_array():
    class CustomReg(KerasRegressor):
        @property
        def feature_encoder(self):
            return FunctionTransformer(lambda x: [x])

    y = np.random.random((100,))
    X = y.reshape((-1, 1))
    est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], predict__batch_size=-1)

    est.fit(X, y)

    with pytest.raises(ValueError, match="requires that `X` implement `shape`"):
        est.predict(X)
