Can't pickle trained model with callback to TensorBoard #236
Comments
Thank you for the clean reproducible example. Interestingly, SciKeras is actually the reason why you are able to pickle the pure-Keras model: we monkey patch … The issue itself stems from pickling the callback. I'll have to think a bit about what can be done here, but I hope that at least sheds some light on the issue for now.
For what it's worth, the actual unpicklable object is …
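The explanation above, that pickling the estimator fails because of the callback it holds, can be illustrated with the standard library alone. This is a minimal sketch under stated assumptions: `FakeTensorBoard` and `Estimator` are hypothetical stand-ins, and the `threading.Lock` attribute is only a placeholder for whatever unpicklable object the real `TensorBoard` callback carries. It is not a claim about TensorBoard's internals.

```python
import pickle
import threading


class FakeTensorBoard:
    """Hypothetical stand-in for a callback holding an unpicklable resource."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        # threading.Lock cannot be pickled; it plays the role of whatever
        # unpicklable state the real callback carries (an assumption).
        self._lock = threading.Lock()


class Estimator:
    """Hypothetical estimator that keeps its callbacks as an attribute."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks


def pickle_error(obj):
    """Return None if obj pickles, else the error message pickle raised."""
    try:
        pickle.dumps(obj)
    except (TypeError, pickle.PicklingError) as exc:
        return str(exc)
    return None


# pickle walks every attribute recursively (estimator -> list -> callback
# -> lock), so one unpicklable leaf makes the whole estimator unpicklable.
print(pickle_error(Estimator()))  # None
print(pickle_error(Estimator(callbacks=[FakeTensorBoard("logs")])))
```

In this toy, the estimator itself is perfectly picklable; only the callback stored on it breaks serialization, so keeping the callback out of the estimator's stored state is enough to make the rest serializable.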
@joooeey The quickest solution to your problem is to pass the callback to SciKeras as a `fit` kwarg:

```python
from pickle import dumps

from numpy import random
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

from scikeras.wrappers import KerasClassifier

# %% shared data
X = random.random((10, 6))
y = random.randint(2, size=10)


def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
    )
    return model


clf = KerasClassifier(
    model=build_fn,
    epochs=5,
    validation_split=0.1,
)
# Pass the callback to fit() rather than to the constructor, so it is not
# stored as an estimator parameter and is not pickled with the classifier.
clf = clf.fit(X, y, callbacks=[TensorBoard("testlogs")])
dumps(clf)
```

Note, however, that:
@stsievert I think this use case is the nail in the coffin for having to keep around …
Do you have a reference issue/PR? I presume you mean passing parameters/keyword arguments through …
We've discussed this topic several times before, e.g. #198, #159 (comment), and #138.
Yes, exactly. What I mean is that we'll have to keep both constructor parameters (sklearn style) and fit/predict **kwargs (Keras style) around, and fully support both. In practice, I think this just means removing any remaining wording/warnings about kwargs deprecation and adding **kwargs support to `partial_fit`. The sklearn constructor parameters are necessary for the sklearn ecosystem to work (e.g. dask-ml hyperparameter tuning), and wherever possible we should encourage them, but certain Keras/TF use cases (in particular the one presented in this issue) simply aren't compatible with sklearn-style constructor parameters and require **kwargs.
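The two parameter styles discussed above can be contrasted in a small toy. This is a sketch, not SciKeras code: `ToyWrapper`, its simplified `get_params` method, and `FakeTensorBoard` are hypothetical, and a `threading.Lock` again stands in for the callback's unpicklable state. Constructor arguments become part of the estimator's state and are pickled with it; `fit` keyword arguments are used only for the duration of the call, which is why passing the callback to `fit` sidesteps the pickling error.

```python
import pickle
import threading


class FakeTensorBoard:
    """Hypothetical stand-in for a callback that holds an unpicklable resource."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._lock = threading.Lock()  # threading.Lock cannot be pickled


class ToyWrapper:
    """Hypothetical wrapper supporting both parameter styles.

    Constructor arguments (sklearn style) are stored as attributes, so tools
    that clone or tune estimators can see them, and they are pickled with the
    estimator. Keyword arguments to fit() (Keras style) are not stored.
    """

    def __init__(self, epochs=1, callbacks=None):
        self.epochs = epochs
        self.callbacks = callbacks

    def get_params(self):
        # Simplified: real sklearn estimators also accept deep=True.
        return {"epochs": self.epochs, "callbacks": self.callbacks}

    def fit(self, X, y, **fit_kwargs):
        # A real wrapper would hand the callbacks to the training loop; this
        # toy only records how many it received and keeps no reference to them.
        per_call = fit_kwargs.get("callbacks", [])
        self.n_callbacks_ = len(self.callbacks or []) + len(per_call)
        return self


def picklable(obj):
    """Return True if obj survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (TypeError, pickle.PicklingError):
        return False
    return True


X, y = [[0.0], [1.0]], [0, 1]

# Constructor style: the callback becomes part of the estimator's state.
ctor_style = ToyWrapper(callbacks=[FakeTensorBoard("logs")]).fit(X, y)
# fit-kwarg style: the callback is used during fit() and then dropped.
kwarg_style = ToyWrapper().fit(X, y, callbacks=[FakeTensorBoard("logs")])

print(picklable(ctor_style), ctor_style.n_callbacks_)    # False 1
print(picklable(kwarg_style), kwarg_style.n_callbacks_)  # True 1
```

The trade-off mirrors the discussion: the constructor style is what sklearn tooling can inspect and tune, while the fit-kwarg style is the only one that can carry objects that must not become part of the saved estimator.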
Yeah, this is good motivation to keep that behavior (passing keyword arguments through to Keras.fit), though I think it should be strongly discouraged.
Unfortunately, in my real code the classifier is part of a pipeline. So now I'm using … This raises:

```
UserWarning: Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `fit` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.
To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `fit`:
`callbacks=[<tensorflow.python.keras.callbacks.TensorBoard object at 0x7f539c414460>]`
More detail is available at https://www.adriangb.com/scikeras/migration.html#variable-keyword-arguments-in-fit-and-predict
```

By the way, the link in the warning doesn't work.
You can safely ignore that warning, as per above we'll probably remove it in the future. Sorry for any confusion this may cause.
You're right, thank you for catching that. This is the correct link, and I'll update the warning (or remove it): https://www.adriangb.com/scikeras/stable/migration.html#variable-keyword-arguments-in-fit-and-predict
Yeah, this works for me now.
Awesome, I'm glad we found you a solution, even if it's not ideal. As I said above, we will probably disable those warnings so that this API is more straightforward to use going forward. Your feedback has been very valuable, so thank you for the issue and for bearing with me during troubleshooting.
Description of the problem
I was excited about `scikeras` because it can interface with `sklearn` and the models can supposedly be pickled. Unfortunately, `scikeras.KerasClassifier` can't be pickled when both of the following conditions are fulfilled:

- the `KerasClassifier` has been fit (trained), and
- the `KerasClassifier` includes a callback to `TensorBoard`.

The equivalent neural network from Keras can be pickled without issue.
Minimum, Complete, Verifiable Example
Stack Trace
The last line of the `# %% scikeras classifier` block raises:

Versions
0.3.3
2.4.1
3.9.5