Add loss="auto" as the default loss #210
base: master
@@ -1,6 +1,6 @@
===================================
Advanced Usage of SciKeras Wrappers
===================================
==============
Advanced Usage
==============

Wrapper Classes
---------------
@@ -128,6 +128,43 @@ offer an easy way to compile and tune compilation parameters. Examples:
In all cases, returning an un-compiled model is equivalent to
calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``.
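The compile-if-needed contract described above can be sketched in plain Python. This is an illustration of the documented behavior, not SciKeras internals; ``ToyModel`` and ``build_model`` are hypothetical stand-ins.

```python
# Sketch of the documented contract: if model_build_fn returns an
# un-compiled model, the wrapper compiles it with **compile_kwargs.
# ToyModel is a hypothetical stand-in, not a Keras class.
class ToyModel:
    def __init__(self):
        self.optimizer = None  # Keras marks a model as compiled via its optimizer
        self.loss = None

    def compile(self, optimizer="rmsprop", loss=None):
        self.optimizer = optimizer
        self.loss = loss


def build_model(model_build_fn, compile_kwargs):
    model = model_build_fn()
    if model.optimizer is None:  # model_build_fn returned an un-compiled model
        model.compile(**compile_kwargs)
    return model


model = build_model(ToyModel, {"optimizer": "adam", "loss": "mse"})
print(model.loss)  # -> mse
```

If ``model_build_fn`` had already compiled the model, the wrapper's ``compile_kwargs`` would be ignored, which is the other half of the equivalence stated above.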
.. _loss-selection:

Loss selection
++++++++++++++

If you do not explicitly define a loss, SciKeras attempts to find a loss
that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`).
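For concreteness, ``type_of_target`` maps targets to the categories used in the tables below; a few representative calls (assuming scikit-learn is installed):

```python
from sklearn.utils.multiclass import type_of_target

print(type_of_target([0, 1, 1, 0]))        # -> binary
print(type_of_target([0, 2, 1, 2]))        # -> multiclass (integer class labels)
print(type_of_target([[1, 0], [0, 1]]))    # -> multilabel-indicator (one-hot matrix)
print(type_of_target([0.5, 1.2, 3.3]))     # -> continuous (regression)
```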
For guidance on selecting losses in Keras, please see Jason Brownlee's
excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_
as well as the `Keras Losses docs`_.

Default losses are selected as follows:

Classification
..............
+-----------+-----------+----------+---------------------------------+
| # outputs | # classes | encoding | loss                            |
+===========+===========+==========+=================================+
| 1         | <= 2      | any      | binary crossentropy             |
+-----------+-----------+----------+---------------------------------+
| 1         | >= 2      | labels   | sparse categorical crossentropy |
+-----------+-----------+----------+---------------------------------+
| 1         | >= 2      | one-hot  | unsupported                     |
+-----------+-----------+----------+---------------------------------+
| > 1       | --        | --       | unsupported                     |
+-----------+-----------+----------+---------------------------------+

Review discussion on this table:

Comment: I'm confused by this table. Let's say I have two classes, one "output," and I don't know my "encoding" (I'm not sure a naive user would know what that means). What loss is chosen? Maybe it'd be simpler to say "KerasClassifier has …

Comment: How about: …

Comment: I'm only for using … I almost prefer this documentation: …

Reply: I totally agree. Reading over this PR again a couple of weeks after writing it, even I get confused.

I think we tried this before. I don't remember the conclusion of those discussions (although I can dig it up), but off the top of my head I think the biggest issue is that new users will copy an example model from a tutorial, many of which do binary classification using a single neuron, or other incompatible architectures. Another common use case is one-hot encoded targets, which … Do you think we can just introspect the model, check that the number of neurons matches the number of classes (and that it is a single-output problem), and raise an error (or maybe a warning) to rescue users from facing whatever cryptic error TF would throw? In other words, with a good enough error message, can we support only the small subset of model architectures that work with …

Reply: I recall introspecting the model to see what loss value should be used, but it tried to abstract too much away from the user (and it got too complicated). I think the new loss for …

Yeah, I had the same idea. If I were developing this library, I think I'd have …

I think both of these should be exceptions. If so, I'd make it clear how to adapt to BaseWrapper.

I think a clear documentation note would resolve this, especially with good error catching. keras.io examples …

Reply: I think #210 (comment) is at least worth exploring (again). I'll open a new PR to test out #210 (comment), both to avoid losing the changes here in the git history and because the new changes are going to be pretty unrelated. Thank you for following up on this PR 😄
Note that SciKeras will not automatically infer the loss for one-hot encoded targets;
you would need to explicitly specify ``loss="categorical_crossentropy"``.
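The classification defaults above amount to a small decision procedure. The following pure-Python sketch is an illustration of the table, not SciKeras source code; the function name and error messages are hypothetical.

```python
def default_classifier_loss(n_outputs: int, n_classes: int, encoding: str) -> str:
    """Illustrative sketch of the default-loss table (not SciKeras source)."""
    if n_outputs > 1:
        # multi-output classification: no default is inferred
        raise ValueError("unsupported: pass an explicit loss for multi-output models")
    if n_classes <= 2:
        return "binary_crossentropy"  # one output, two classes, any encoding
    if encoding == "labels":
        return "sparse_categorical_crossentropy"  # integer class labels
    # one-hot targets are not inferred; see the note above
    raise ValueError('unsupported: specify loss="categorical_crossentropy"')


print(default_classifier_loss(1, 2, "labels"))  # -> binary_crossentropy
print(default_classifier_loss(1, 3, "labels"))  # -> sparse_categorical_crossentropy
```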
Regression
..........

Regression always defaults to mean squared error.
For multi-output models, Keras will use the sum of each output's loss.
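The "sum of each output's loss" behavior can be illustrated numerically with plain-Python mean squared errors (a worked illustration, not Keras code):

```python
def mse(y_true, y_pred):
    # mean squared error for a single output
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# predictions for a hypothetical two-output regressor
loss_1 = mse([1.0, 2.0], [1.0, 3.0])  # 0.5
loss_2 = mse([0.0, 0.0], [1.0, 1.0])  # 1.0
total_loss = loss_1 + loss_2          # 1.5, the value Keras would report
print(total_loss)  # -> 1.5
```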
Arguments to ``model_build_fn``
-------------------------------
@@ -287,3 +324,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc
.. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks

.. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics

.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses

.. _How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/
Comment: I think this section could use some usage examples, and clarification of what "output" and "encoding" mean.