Implement batch_size=-1 #194

Merged · 7 commits · Feb 26, 2021
42 changes: 37 additions & 5 deletions scikeras/wrappers.py
@@ -80,6 +80,11 @@ class BaseWrapper(BaseEstimator):
If False, subsequent fit calls will reset the entire model.
This has no impact on partial_fit, which always trains
for a single epoch starting from the current epoch.
batch_size : Union[int, None], default None
Number of samples per gradient update.
This value is applied to both `fit` and `predict`. To specify different values,
pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
To use all samples in a single batch, pass `batch_size=-1`.

Attributes
----------
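
For context, a minimal usage sketch of the parameter described above. The builder function `build_regressor` is illustrative only (any function returning a compiled Keras model works); it is not part of this diff.

import tensorflow as tf
from scikeras.wrappers import KerasRegressor

def build_regressor():
    # Hypothetical minimal builder: a single-input linear model.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss="mse", optimizer="rmsprop")
    return model

# One batch size for both fit and predict:
est = KerasRegressor(build_regressor, batch_size=32)

# Route different batch sizes to fit and predict:
est = KerasRegressor(build_regressor, fit__batch_size=32, predict__batch_size=1000)

# Full-batch training/inference: -1 resolves to X.shape[0] at call time.
est = KerasRegressor(build_regressor, batch_size=-1)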
@@ -212,6 +217,7 @@ def __init__(
None,
] = None,
batch_size: Union[int, None] = None,
validation_batch_size: Union[int, None] = None,
verbose: int = 1,
callbacks: Union[
List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -233,6 +239,7 @@
self.loss = loss
self.metrics = metrics
self.batch_size = batch_size
self.validation_batch_size = validation_batch_size
self.verbose = verbose
self.callbacks = callbacks
self.validation_split = validation_split
@@ -487,6 +494,15 @@ def _fit_keras_model(
fit_args["epochs"] = initial_epoch + epochs
fit_args["initial_epoch"] = initial_epoch
fit_args.update(kwargs)
for bs_kwarg in ("batch_size", "validation_batch_size"):
if bs_kwarg in fit_args:
if fit_args[bs_kwarg] == -1:
try:
fit_args[bs_kwarg] = X.shape[0]
except AttributeError:
raise ValueError(
f"`{bs_kwarg}=-1` requires that `X` implement `shape`"
)

if self._random_state is not None:
with TFRandomState(self._random_state):
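
To make the resolution step above concrete, here is a standalone sketch of the same logic outside of SciKeras (the helper name `resolve_batch_size` is ours, not the PR's):

import numpy as np

def resolve_batch_size(batch_size, X):
    # Mirror of the PR's logic: -1 means "all samples", which requires
    # that X expose `shape` (plain Python lists, for example, do not).
    if batch_size == -1:
        try:
            return X.shape[0]
        except AttributeError:
            raise ValueError("`batch_size=-1` requires that `X` implement `shape`")
    return batch_size

assert resolve_batch_size(-1, np.zeros((100, 1))) == 100  # full batch
assert resolve_batch_size(32, np.zeros((100, 1))) == 32   # passed through unchanged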
@@ -918,9 +934,17 @@ def _predict_raw(self, X, **kwargs):
params, destination="predict", pass_filter=self._predict_kwargs
)
pred_args.update(kwargs)
if "batch_size" in pred_args:
if pred_args["batch_size"] == -1:
try:
pred_args["batch_size"] = X.shape[0]
except AttributeError:
raise ValueError(
"`batch_size=-1` requires that `X` implement `shape`"
)

# predict with Keras model
- y_pred = self.model_.predict(X, **pred_args)
+ y_pred = self.model_.predict(x=X, **pred_args)

return y_pred

@@ -1135,6 +1159,11 @@ class KerasClassifier(BaseWrapper):
If False, subsequent fit calls will reset the entire model.
This has no impact on partial_fit, which always trains
for a single epoch starting from the current epoch.
batch_size : Union[int, None], default None
Number of samples per gradient update.
This value is applied to both `fit` and `predict`. To specify different values,
pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
To use all samples in a single batch, pass `batch_size=-1`.
class_weight : Union[Dict[Any, float], str, None], default None
Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one.
@@ -1241,6 +1270,7 @@ def __init__(
None,
] = None,
batch_size: Union[int, None] = None,
validation_batch_size: Union[int, None] = None,
verbose: int = 1,
callbacks: Union[
List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
@@ -1262,6 +1292,7 @@
loss=loss,
metrics=metrics,
batch_size=batch_size,
validation_batch_size=validation_batch_size,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
@@ -1473,32 +1504,33 @@ class KerasRegressor(BaseWrapper):
must return a compiled instance of a Keras Model
to be used by `fit`, `predict`, etc.
If None, you must implement ``_keras_build_fn``.

optimizer : Union[str, tf.keras.optimizers.Optimizer, Type[tf.keras.optimizers.Optimizer]], default "rmsprop"
This can be a string for Keras' built-in optimizers,
an instance of tf.keras.optimizers.Optimizer
or a class inheriting from tf.keras.optimizers.Optimizer.
Only strings and classes support parameter routing.

loss : Union[Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None], default None
The loss function to use for training.
This can be a string for Keras' built-in losses,
an instance of tf.keras.losses.Loss
or a class inheriting from tf.keras.losses.Loss.
Only strings and classes support parameter routing.

random_state : Union[int, np.random.RandomState, None], default None
Set the TensorFlow random number generators to a
reproducible deterministic state using this seed.
Pass an int for reproducible results across multiple
function calls.

warm_start : bool, default False
If True, subsequent calls to fit will *not* reset
the model parameters but *will* reset the epoch to zero.
If False, subsequent fit calls will reset the entire model.
This has no impact on partial_fit, which always trains
for a single epoch starting from the current epoch.
batch_size : Union[int, None], default None
Number of samples per gradient update.
This value is applied to both `fit` and `predict`. To specify different values,
pass, for example, `fit__batch_size=32` and `predict__batch_size=1000`.
To use all samples in a single batch, pass `batch_size=-1`.

Attributes
----------
105 changes: 105 additions & 0 deletions tests/test_parameters.py
@@ -8,6 +8,7 @@
from sklearn.base import clone
from sklearn.datasets import make_blobs, make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import FunctionTransformer

from scikeras.wrappers import KerasClassifier, KerasRegressor

@@ -320,3 +321,107 @@ def test_kwargs(wrapper, builder):
or hasattr(est, "fit__" + k)
or hasattr(est, "predict__" + k)
)


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
def test_batch_size_all_fit(length, prefix, base):

kw = prefix + base

y = np.random.random((length,))
X = y.reshape((-1, 1))
est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

est.initialize(X, y)

fit_orig = est.model_.fit

def check_batch_size(**kwargs):
assert kwargs[base] == X.shape[0]
return fit_orig(**kwargs)

with mock.patch.object(est.model_, "fit", new=check_batch_size):
est.fit(X, y)


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "predict__"))
@pytest.mark.parametrize("base", ("batch_size",))
def test_batch_size_all_predict(length, prefix, base):

kw = prefix + base

y = np.random.random((length,))
X = y.reshape((-1, 1))
est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

est.fit(X, y)

pred_orig = est.model_.predict

def check_batch_size(**kwargs):
assert kwargs[base] == X.shape[0]
return pred_orig(**kwargs)

with mock.patch.object(est.model_, "predict", new=check_batch_size):
est.predict(X)


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
def test_batch_size_all_fit(length, prefix, base):

kw = prefix + base

y = np.random.random((length,))
X = y.reshape((-1, 1))
est = KerasRegressor(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

est.initialize(X, y)

fit_orig = est.model_.fit

def check_batch_size(**kwargs):
assert kwargs[base] == X.shape[0]
return fit_orig(**kwargs)

with mock.patch.object(est.model_, "fit", new=check_batch_size):
est.fit(X, y)


@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
def test_batch_size_all_fit_non_array(prefix, base):

kw = prefix + base

class CustomReg(KerasRegressor):
@property
def feature_encoder(self):
return FunctionTransformer(lambda x: [x])

y = np.random.random((100,))
X = y.reshape((-1, 1))
est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], **{kw: -1})

with pytest.raises(ValueError, match="requires that `X` implement `shape`"):
est.fit(X, y)


def test_batch_size_all_predict_non_array():
class CustomReg(KerasRegressor):
@property
def feature_encoder(self):
return FunctionTransformer(lambda x: [x])

y = np.random.random((100,))
X = y.reshape((-1, 1))
est = CustomReg(dynamic_regressor, hidden_layer_sizes=[], predict__batch_size=-1)

est.fit(X, y)

with pytest.raises(ValueError, match="requires that `X` implement `shape`"):
est.predict(X)
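
As a quick check, the new tests can be run with, for example:

pytest tests/test_parameters.py -k batch_size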