Add loss="auto" as the default loss #210

Open · wants to merge 19 commits into master (viewing changes from 1 commit)
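In short, this PR makes `loss="auto"` the default so the wrappers infer a loss from the target when the user gives none. A minimal usage sketch of the intended behavior (the dataset and the `get_clf` builder are illustrative, not part of the PR):

```python
import numpy as np
from tensorflow import keras
from scikeras.wrappers import KerasClassifier

def get_clf():
    # Single-output binary head: one sigmoid unit.
    model = keras.Sequential([
        keras.layers.Dense(8, activation="relu", input_shape=(4,)),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    return model  # deliberately un-compiled; SciKeras compiles it

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)

clf = KerasClassifier(model=get_clf)  # loss defaults to "auto" under this PR
clf.fit(X, y)  # binary target -> compiled with "binary_crossentropy"
```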
2 changes: 1 addition & 1 deletion scikeras/utils/transformers.py
@@ -154,7 +154,7 @@ def fit(self, y: np.ndarray) -> "ClassifierLabelEncoder":
"multiclass-multioutput": FunctionTransformer(),
"multilabel-indicator": FunctionTransformer(),
}
if is_categorical_crossentropy(self.loss):
if target_type == "multiclass" and is_categorical_crossentropy(self.loss):
encoders["multiclass"] = make_pipeline(
TargetReshaper(),
OneHotEncoder(
Expand Down
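For context, the new guard keys off scikit-learn's target-type detection, so the one-hot pipeline is installed only for genuinely multiclass targets. A quick sketch of the relevant `type_of_target` outputs:

```python
from sklearn.utils.multiclass import type_of_target

type_of_target([0, 1, 1, 0])            # 'binary'
type_of_target([0, 1, 2, 2])            # 'multiclass'
type_of_target([[1, 0, 1], [0, 1, 1]])  # 'multilabel-indicator'
```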
135 changes: 121 additions & 14 deletions scikeras/wrappers.py
@@ -22,6 +22,7 @@
 from tensorflow.keras import optimizers as optimizers_module
 from tensorflow.keras.models import Model
 from tensorflow.keras.utils import register_keras_serializable
+from tensorflow.python.types.core import Value

 from scikeras._utils import (
     TFRandomState,
@@ -345,27 +346,36 @@ def _get_compile_kwargs(self):
         compile_kwargs = route_params(
             init_params, destination="compile", pass_filter=self._compile_kwargs,
         )
-        compile_kwargs["optimizer"] = _class_from_strings(
-            compile_kwargs["optimizer"], optimizers_module.get
-        )
+        try:
+            compile_kwargs["optimizer"] = _class_from_strings(
+                compile_kwargs["optimizer"], optimizers_module.get
+            )
+        except ValueError:
+            pass  # unknown optimizer
         compile_kwargs["optimizer"] = unflatten_params(
             items=compile_kwargs["optimizer"],
             params=route_params(
                 init_params, destination="optimizer", pass_filter=set(), strict=True,
             ),
         )
-        compile_kwargs["loss"] = _class_from_strings(
-            compile_kwargs["loss"], losses_module.get
-        )
+        try:
+            compile_kwargs["loss"] = _class_from_strings(
+                compile_kwargs["loss"], losses_module.get
+            )
+        except ValueError:
+            pass  # unknown loss
         compile_kwargs["loss"] = unflatten_params(
             items=compile_kwargs["loss"],
             params=route_params(
                 init_params, destination="loss", pass_filter=set(), strict=False,
             ),
         )
-        compile_kwargs["metrics"] = _class_from_strings(
-            compile_kwargs["metrics"], metrics_module.get
-        )
+        try:
+            compile_kwargs["metrics"] = _class_from_strings(
+                compile_kwargs["metrics"], metrics_module.get
+            )
+        except ValueError:
+            pass  # unknown metric
         compile_kwargs["metrics"] = unflatten_params(
             items=compile_kwargs["metrics"],
             params=route_params(
@@ -422,8 +432,14 @@ def _build_keras_model(self):

     def _ensure_compiled_model(self) -> None:
         # compile model if user gave us an un-compiled model
-        if not (hasattr(self.model_, "loss") and hasattr(self.model_, "optimizer")):
-            self.model_.compile(**self._get_compile_kwargs())
+        if not (
+            getattr(self.model_, "loss", None)
+            and getattr(self.model_, "optimizer", None)
+        ):
+            self._compile_model(self._get_compile_kwargs())
+
+    def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
+        self.model_.compile(**compile_kwargs)

     def _fit_keras_model(
         self,
@@ -549,8 +565,14 @@ def _check_model_compatibility(self, y: np.ndarray) -> None:
             for x in [self.loss, self.model_.loss]
         ):  # filter out loss list/dicts/etc.
             if default_val is not None:
-                default_val = loss_name(default_val)
-            given = loss_name(self.loss)
+                try:
+                    default_val = loss_name(default_val)
+                except ValueError:
+                    pass
Review thread (resolved):

Owner (author): Todo: figure out what to do here, or even refactor this check like in #208.

Collaborator: I think loss_name returning None says "the provided loss has no name/is not recognized."

Owner (author): I'll try to put that change in. I still think this check needs to be refactored.

Owner (author): This worked out. Only small doc and test changes needed.
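A sketch of the call-site simplification the thread converges on, assuming `loss_name` is changed to return `None` for unrecognized losses instead of raising (this is the proposal being discussed, not the code in this commit; `check_loss` is a hypothetical helper):

```python
from typing import Optional

from scikeras.utils import loss_name  # existing SciKeras helper

def check_loss(given_loss, model_loss, default: Optional[str]) -> None:
    # Hypothetical: relies on loss_name returning None for unrecognized
    # losses rather than raising ValueError.
    given = loss_name(given_loss)
    if given is None:
        return  # unrecognized loss -> nothing to validate
    got = loss_name(model_loss)
    if given != default and got != given:
        raise ValueError("loss given to the wrapper differs from the compiled loss")
```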

+            try:
+                given = loss_name(self.loss)
+            except ValueError:
+                return  # idk placeholder
             got = loss_name(self.model_.loss)
             if given != default_val and got != given:
                 raise ValueError(
@@ -1257,7 +1279,7 @@ def __init__(
] = "rmsprop",
loss: Union[
Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None
] = None,
] = "auto",
adriangb marked this conversation as resolved.
Show resolved Hide resolved
metrics: Union[
List[
Union[
@@ -1310,6 +1332,22 @@ def _type_of_target(self, y: np.ndarray) -> str:
         target_type = type_of_target(self.classes_)
         return target_type

+    def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
+        if compile_kwargs["loss"] == "auto":
+            if len(self.model_.outputs) > 1:
+                raise ValueError(
+                    'Only single-output models are supported with `loss="auto"`'
+                )
+            if self.target_type_ == "binary":
+                compile_kwargs["loss"] = "binary_crossentropy"
+            else:
+                if self.model_.outputs[0].shape[1] == 1:
+                    raise ValueError(
+                        "Multi-class targets require the model to have >1 output unit"
+                    )
+                compile_kwargs["loss"] = "categorical_crossentropy"
+        self.model_.compile(**compile_kwargs)

     @staticmethod
     def scorer(y_true, y_pred, **kwargs) -> float:
         """Scoring function for KerasClassifier.
@@ -1611,6 +1649,75 @@ class KerasRegressor(BaseWrapper):
         **BaseWrapper._tags,
     }

+    def __init__(
+        self,
+        model: Union[None, Callable[..., tf.keras.Model], tf.keras.Model] = None,
+        *,
+        build_fn: Union[
+            None, Callable[..., tf.keras.Model], tf.keras.Model
+        ] = None,  # for backwards compatibility
+        warm_start: bool = False,
+        random_state: Union[int, np.random.RandomState, None] = None,
+        optimizer: Union[
+            str, tf.keras.optimizers.Optimizer, Type[tf.keras.optimizers.Optimizer]
+        ] = "rmsprop",
+        loss: Union[
+            Union[str, tf.keras.losses.Loss, Type[tf.keras.losses.Loss], Callable], None
+        ] = "auto",
+        metrics: Union[
+            List[
+                Union[
+                    str,
+                    tf.keras.metrics.Metric,
+                    Type[tf.keras.metrics.Metric],
+                    Callable,
+                ]
+            ],
+            None,
+        ] = None,
+        batch_size: Union[int, None] = None,
+        validation_batch_size: Union[int, None] = None,
+        verbose: int = 1,
+        callbacks: Union[
+            List[Union[tf.keras.callbacks.Callback, Type[tf.keras.callbacks.Callback]]],
+            None,
+        ] = None,
+        validation_split: float = 0.0,
+        shuffle: bool = True,
+        run_eagerly: bool = False,
+        epochs: int = 1,
+        **kwargs,
+    ):
+
+        # Parse hardcoded params
+        self.model = model
+        self.build_fn = build_fn
+        self.warm_start = warm_start
+        self.random_state = random_state
+        self.optimizer = optimizer
+        self.loss = loss
+        self.metrics = metrics
+        self.batch_size = batch_size
+        self.validation_batch_size = validation_batch_size
+        self.verbose = verbose
+        self.callbacks = callbacks
+        self.validation_split = validation_split
+        self.shuffle = shuffle
+        self.run_eagerly = run_eagerly
+        self.epochs = epochs
+
+        # Unpack kwargs
+        vars(self).update(**kwargs)
+
+        # Save names of kwargs into set
+        if kwargs:
+            self._user_params = set(kwargs)
+
+    def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None:
+        if compile_kwargs["loss"] == "auto":
+            compile_kwargs["loss"] = "mean_squared_error"
+        self.model_.compile(**compile_kwargs)
+
     @staticmethod
     def scorer(y_true, y_pred, **kwargs) -> float:
         """Scoring function for KerasRegressor.
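The try/except blocks added to `_get_compile_kwargs` above exist because Keras' string lookups reject identifiers they don't know, and `"auto"` is a SciKeras-only sentinel that must survive untouched until `_compile_model` resolves it. A quick sketch of the failure mode being guarded against (TF 2.x behavior; the exact message may vary):

```python
import tensorflow as tf

tf.keras.losses.get("binary_crossentropy")  # resolves to a loss function
try:
    tf.keras.losses.get("auto")  # "auto" is not a Keras loss identifier
except ValueError as e:
    print(e)  # e.g. "Could not interpret loss function identifier: auto"
```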
21 changes: 12 additions & 9 deletions tests/mlp_models.py
@@ -26,11 +26,16 @@ def dynamic_classifier(
     for layer_size in hidden_layer_sizes:
         hidden = Dense(layer_size, activation="relu")(hidden)

+    if compile_kwargs["loss"] == "auto":
+        loss = None
+    else:
+        loss = compile_kwargs["loss"]
+
     if target_type_ == "binary":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         out = [Dense(1, activation="sigmoid")(hidden)]
     elif target_type_ == "multilabel-indicator":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         if isinstance(n_classes_, list):
             out = [
                 Dense(1, activation="sigmoid")(hidden)
@@ -39,13 +44,11 @@
         else:
             out = Dense(n_classes_, activation="softmax")(hidden)
     elif target_type_ == "multiclass-multioutput":
-        compile_kwargs["loss"] = compile_kwargs["loss"] or "binary_crossentropy"
+        compile_kwargs["loss"] = loss or "binary_crossentropy"
         out = [Dense(n, activation="softmax")(hidden) for n in n_classes_]
     else:
         # multiclass
-        compile_kwargs["loss"] = (
-            compile_kwargs["loss"] or "sparse_categorical_crossentropy"
-        )
+        compile_kwargs["loss"] = loss or "sparse_categorical_crossentropy"
         out = [Dense(n_classes_, activation="softmax")(hidden)]

     model = Model(inp, out)
Expand All @@ -60,13 +63,13 @@ def dynamic_regressor(
     meta: Optional[Dict[str, Any]] = None,
     compile_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Model:
-    """Creates a basic MLP regressor dynamically.
-    """
+    """Creates a basic MLP regressor dynamically."""
     # get parameters
     n_features_in_ = meta["n_features_in_"]
     n_outputs_ = meta["n_outputs_"]

-    compile_kwargs["loss"] = compile_kwargs["loss"] or "mse"
+    if compile_kwargs["loss"] == "auto":
+        compile_kwargs["loss"] = "mean_squared_error"

     inp = Input(shape=(n_features_in_,))
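Taken together: with this commit, model builders check for the `"auto"` sentinel rather than a `None` loss, and `KerasRegressor._compile_model` resolves `"auto"` to mean squared error at compile time. An end-to-end sketch under those assumptions (the data and the `get_reg` builder are illustrative, not from the PR):

```python
import numpy as np
from tensorflow import keras
from scikeras.wrappers import KerasRegressor

def get_reg():
    model = keras.Sequential([
        keras.layers.Dense(8, activation="relu", input_shape=(3,)),
        keras.layers.Dense(1),
    ])
    return model  # un-compiled; _compile_model handles loss="auto"

X = np.random.rand(64, 3)
y = np.random.rand(64)

reg = KerasRegressor(model=get_reg)  # loss="auto" by default under this PR
reg.fit(X, y)  # compiled with "mean_squared_error"
print(reg.score(X, y))  # R^2, per KerasRegressor.scorer
```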