Can't pickle trained model with callback to TensorBoard #236

Closed
joooeey opened this issue Jul 9, 2021 · 11 comments · Fixed by #238


@joooeey

joooeey commented Jul 9, 2021

Description of the problem

I was excited about scikeras because it can interface with sklearn and the models can supposedly be pickled. Unfortunately scikeras.KerasClassifier can't be pickled when both of the following conditions are fulfilled:

  • the KerasClassifier includes a callback to TensorBoard.
  • it has been trained

The equivalent neural network from Keras can be pickled without issue.

Minimum, Complete, Verifiable Example

from joblib import dump
# from pickle import dump  # causes the same problem
from numpy import random

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier


# %% shared data

X = random.random((10, 6))
y = random.randint(2, size=10)

def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
    )
    
    return model

# %% scikeras classifier [breaks]

clf = KerasClassifier(
    model=build_fn,
    epochs=5,
    validation_split=0.1,
    callbacks=[TensorBoard("testlogs")],  # won't break without this line
)

clf = clf.fit(X, y)  # won't break without this line

dump(clf, open("test_scikeras.pkl", "wb"))  # raises InvalidArgumentError


# %% same classifier in pure tf.keras [works]

model = build_fn()

model.fit(
    X,
    y,
    epochs=5,
    validation_split=0.1,
    callbacks=[TensorBoard("testlogs")]
)

dump(model, open("test_keras.pkl", "wb"))  # works

Stack Trace

The last line of the # %% scikeras classifier [breaks] block raises:

Traceback (most recent call last):

  File "/home/lukas/Desktop/tensorboard_temp.py", line 52, in <module>
    dump(clf, open("test_scikeras.pkl", "wb"))  # raises InvalidArgumentError

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 482, in dump
    NumpyPickler(filename, protocol=protocol).dump(value)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 487, in dump
    self.save(obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 931, in save_list
    self._batch_appends(obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 958, in _batch_appends
    save(tmp[0])

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 717, in save_reduce
    save(state)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 997, in _batch_setitems
    save(v)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/joblib/numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/pickle.py", line 578, in save
    rv = reduce(self.proto)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1000, in __reduce__
    return convert_to_tensor, (self._numpy(),)

  File "/home/lukas/anaconda3/envs/tf_gpu/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1039, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access

  File "<string>", line 3, in raise_from

InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

Versions

  • SciKeras 0.3.3
  • TensorFlow 2.4.1
  • Python 3.9.5
@adriangb
Owner

adriangb commented Jul 9, 2021

Thank you for the clean reproducible example.

Interestingly, SciKeras is actually the reason why you are able to pickle the pure-Keras model: we monkey patch tf.keras.Model to make it picklable (here). If you remove the SciKeras import, that will fail with an error along the lines of "can't pickle weakref object".
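The general technique behind such a monkey patch can be sketched with the standard library alone: assign a `__reduce__` to a class so that pickle stores a serialized snapshot instead of the object's raw, unpicklable attributes. `Model`, `serialize`, and `deserialize` below are hypothetical stand-ins, not SciKeras or TensorFlow APIs; in SciKeras the snapshot would come from TensorFlow's own model saving rather than JSON.

```python
import json
import pickle
import threading


class Model:
    """Hypothetical stand-in for a framework model with an unpicklable member."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.Lock()  # plain pickle cannot handle locks


def serialize(model):
    # Stand-in for the framework's own serializer (e.g. saving a model to bytes).
    return json.dumps({"weights": model.weights}).encode()


def deserialize(blob):
    # Rebuild a fresh Model (including a new lock) from the serialized bytes.
    return Model(json.loads(blob.decode())["weights"])


def _reduce_model(self):
    # Tell pickle to call deserialize(blob) when the object is loaded.
    return deserialize, (serialize(self),)


# The monkey patch: every Model instance now pickles via the custom serializer.
Model.__reduce__ = _reduce_model

restored = pickle.loads(pickle.dumps(Model([1.0, 2.0])))
print(restored.weights)  # [1.0, 2.0]
```

Because the patch is applied to the class, it only takes effect once the patching module has been imported, which matches the observation that removing the SciKeras import makes plain Keras pickling fail again.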

The issue itself stems from pickling the callback.
Unfortunately, many things in TensorFlow aren't picklable using standard Python pickling facilities.
So SciKeras uses TensorFlow's own serialization support to serialize models.
But TensorFlow's serialization support is limited and rigid. For example, it can serialize an entire Model, but not Callback or Optimizer instances.
Generally this is OK because things like optimizer instances are stored as part of the Keras model itself, and so TensorFlow knows how to serialize it.
But models don't hold references to callbacks; callbacks are only passed in as fit/predict arguments.
So SciKeras has to hold the reference to callbacks in order to support stateful callbacks.
This, of course, causes an issue when pickling (I had not thought of this when I implemented callback support, so thank you for bringing it up).

I'll have to think a bit about what can be done here, but I hope that at least sheds some light on the issue for now.
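To see why holding a reference to the callback matters, here is a stdlib-only sketch. `LoggingCallback` and `Wrapper` are hypothetical stand-ins for the TensorBoard callback and the SciKeras estimator, and a `threading.Lock` plays the role of the TensorFlow resource that cannot be pickled. Pickling walks every attribute reachable from the wrapper, so a single unpicklable object inside a stored callback is enough to make the whole estimator fail; that recursive walk is the chain of `save_dict`/`save_list` frames in the traceback above.

```python
import pickle
import threading


class LoggingCallback:
    """Hypothetical callback that owns an unpicklable writer once training starts."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._writers = {}

    def on_train_begin(self):
        # Stand-in for TensorBoard creating its summary writers during fit.
        self._writers["train"] = threading.Lock()


class Wrapper:
    """Hypothetical sklearn-style wrapper that keeps callbacks as a constructor parameter."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks

    def fit(self):
        for cb in self.callbacks or []:
            cb.on_train_begin()
        return self


def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


est = Wrapper(callbacks=[LoggingCallback("testlogs")])
print(is_picklable(est))  # True: the callback has not created a writer yet
est.fit()
print(is_picklable(est))  # False: est -> callbacks -> _writers -> lock
```

This mirrors the two conditions in the original report: pickling only breaks once the callback is stored on the estimator and training has run.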

@adriangb
Owner

adriangb commented Jul 9, 2021

For what it's worth, the actual unpicklable object is TensorBoard._writers. Deleting it makes the callback picklable.
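One way to act on that observation is to leave the offending attribute out of the pickled state. The sketch below is a hypothetical illustration and not necessarily how #238 resolved the issue: `LoggingCallback` stands in for `TensorBoard`, its `__getstate__` drops `_writers`, and its `__setstate__` restores an empty dict, on the assumption that the writers can simply be recreated the next time training begins.

```python
import pickle
import threading


class LoggingCallback:
    """Hypothetical stand-in for TensorBoard with an unpicklable `_writers` dict."""

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self._writers = {}

    def on_train_begin(self):
        self._writers["train"] = threading.Lock()  # unpicklable resource

    def __getstate__(self):
        # Copy the instance dict and leave the live writers out of the pickle.
        state = self.__dict__.copy()
        state.pop("_writers", None)
        return state

    def __setstate__(self, state):
        # Restore the picklable attributes and start with no open writers.
        self.__dict__.update(state)
        self._writers = {}


cb = LoggingCallback("testlogs")
cb.on_train_begin()
restored = pickle.loads(pickle.dumps(cb))
print(restored.log_dir, restored._writers)  # testlogs {}
```

The trade-off is that any state held by the writers is lost on unpickling, which is acceptable only if the callback can rebuild it.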

@adriangb
Owner

@joooeey The quickest solution to your problem is going to be to pass the callback to SciKeras as a fit kwarg:

from pickle import dumps
from numpy import random

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier


# %% shared data

X = random.random((10, 6))
y = random.randint(2, size=10)

def build_fn():
    """Build sequential neural network."""
    model = Sequential()
    model.add(Dense(30, activation="relu", input_shape=(6, )))
    model.add(Dense(20, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    
    model.compile(
        optimizer="rmsprop",
        loss="binary_crossentropy",
    )
    
    return model

clf = KerasClassifier(
    model=build_fn,
    epochs=5,
    validation_split=0.1,
)

clf = clf.fit(X, y, callbacks=[TensorBoard("testlogs")])

dumps(clf)

Note however that:

  1. The callback won't be serialized / saved if you pickle/unpickle the model.
  2. You can't hyperparameter tune it (not that it would make sense in this case). But you can pass it to GridSearchCV and other hyperparameter tuning tools if they support **fit_params.
  3. You can't use partial_fit (and hence Dask) since our partial_fit doesn't support passing arbitrary arguments.
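The difference between the two approaches comes down to what the estimator stores. In this stdlib-only sketch (again with hypothetical `LoggingCallback` and `Wrapper` classes and a lock standing in for the TensorFlow resource), callbacks given to `fit` are used for that call only and never assigned to `self`, so the fitted estimator pickles cleanly and, as note 1 says, the unpickled estimator carries no callbacks.

```python
import pickle
import threading


class LoggingCallback:
    """Hypothetical callback that creates an unpicklable writer during training."""

    def __init__(self):
        self._writers = {}

    def on_train_begin(self):
        self._writers["train"] = threading.Lock()


class Wrapper:
    """Hypothetical wrapper accepting callbacks either at init or as a fit kwarg."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks

    def fit(self, callbacks=None):
        # fit-time callbacks take precedence but only live in a local variable.
        active = callbacks if callbacks is not None else (self.callbacks or [])
        for cb in active:
            cb.on_train_begin()
        self.fitted_ = True
        return self


est = Wrapper().fit(callbacks=[LoggingCallback()])
restored = pickle.loads(pickle.dumps(est))
print(restored.fitted_, restored.callbacks)  # True None
```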

@adriangb
Owner

@stsievert I think this use case is the nail in the coffin for having to keep around **kwargs

@stsievert
Collaborator

the nail in the coffin for having to keep around **kwargs

Do you have a reference issue/PR? I presume you mean passing parameters/keyword arguments through fit instead of always specifying those parameters at initialization. Is that correct?

@adriangb
Owner

adriangb commented Jul 10, 2021

Do you have a reference issue/PR?

We discussed the topic several times previously, e.g. #198, #159 (comment), and #138.

you mean passing parameters/keyword arguments through fit instead of always specifying those parameters at initialization

Yes, exactly.

What I mean is that we'll have to keep both constructor parameters (sklearn style) and fit/predict **kwargs (Keras style) around, and fully support both. In practice I think this just means removing any remaining wording/warnings about kwargs deprecation, and adding support to partial_fit.

The sklearn constructor parameters are necessary for the sklearn ecosystem to work (e.g. dask-ml hyperparameter tuning), and wherever possible we should encourage them, but certain Keras/TF use cases (in particular the one presented in this issue) simply aren't compatible with sklearn-style constructor parameters and require **kwargs.

@stsievert
Collaborator

Yeah, this is good motivation to keep that behavior (passing keyword arguments through to Keras.fit).

I think that behavior should be strongly discouraged.

@joooeey
Author

joooeey commented Jul 12, 2021

@joooeey The quickest solution to your problem is going to be to pass the callback to SciKeras as a fit kwarg:
[...]

Unfortunately, in my real code the classifier is part of a pipeline. So now I'm using my_pipeline.fit(X, y, kerasclassifier__callbacks=[TensorBoard()])

This raises

UserWarning: Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `fit` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.

To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `fit`:

`callbacks=[<tensorflow.python.keras.callbacks.TensorBoard object at 0x7f539c414460>]`

More detail is available at https://www.adriangb.com/scikeras/migration.html#variable-keyword-arguments-in-fit-and-predict

By the way, the link in the warning doesn't work.
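The `kerasclassifier__callbacks` name follows scikit-learn's `Pipeline` convention: fit parameters are prefixed with the step name and a double underscore, and `make_pipeline` (presumably used here) names each step after its lowercased class name. The stdlib-only sketch below imitates that routing for a hypothetical two-step pipeline (a `StandardScaler` followed by the `KerasClassifier`) to show which keyword arguments reach the classifier's `fit`; `route_fit_params` is a hypothetical helper, not scikit-learn's implementation.

```python
def route_fit_params(step_names, **fit_params):
    """Split `step__param` keyword arguments into one dict per pipeline step."""
    routed = {name: {} for name in step_names}
    for key, value in fit_params.items():
        step, sep, param = key.partition("__")
        if not sep or step not in routed:
            raise ValueError(f"cannot route fit parameter {key!r}")
        routed[step][param] = value
    return routed


# Step names as make_pipeline would generate them for a scaler and a KerasClassifier.
routed = route_fit_params(
    ["standardscaler", "kerasclassifier"],
    kerasclassifier__callbacks=["<TensorBoard callback>"],
)
print(routed)
# {'standardscaler': {}, 'kerasclassifier': {'callbacks': ['<TensorBoard callback>']}}
```

Only the classifier step receives `callbacks`, so SciKeras sees it as a fit keyword argument, which is why the `**kwargs` warning appears.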

@adriangb
Owner

You can safely ignore that warning; as mentioned above, we'll probably remove it in the future. Sorry for any confusion this may cause.
Other than the warning, does that work for you?

By the way, the link in the warning doesn't work.

You're right, thank you for catching that. This is the correct link, and I'll update the warning (or remove it): https://www.adriangb.com/scikeras/stable/migration.html#variable-keyword-arguments-in-fit-and-predict
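If the `**kwargs` warning is distracting in the meantime, it can be silenced with the standard `warnings` module. This sketch assumes the message text stays as quoted in the report (SciKeras may change or drop it); `filterwarnings` matches its `message` regular expression against the start of the warning text. The `warnings.warn` calls only simulate what SciKeras emits, so the snippet runs without SciKeras installed.

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning for this demo
    # In your own script, this filterwarnings call alone (at top level) is enough.
    # It is inserted at the front of the filter list, so it wins over "always".
    warnings.filterwarnings(
        "ignore",
        message="Passing estimator parameters as keyword arguments",
        category=UserWarning,
    )
    # Simulated warnings; in real code the first one would come from SciKeras's fit.
    warnings.warn(
        "Passing estimator parameters as keyword arguments (aka as `**kwargs`) "
        "to `fit` is not supported by the Scikit-Learn API",
        UserWarning,
    )
    warnings.warn("some unrelated warning", UserWarning)

print([str(w.message) for w in caught])  # ['some unrelated warning']
```

Filtering on the message rather than the whole `UserWarning` category keeps other user warnings visible.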

@joooeey
Author

joooeey commented Jul 14, 2021

Yea this works for me now.

@adriangb
Owner

adriangb commented Jul 14, 2021

Awesome, I'm glad we found you a solution, even if it's not ideal.

Like I said above, we will probably disable those warnings so that this API will be more straightforward to use going forward.

Your feedback has been very valuable, so thank you for the issue and bearing with me during troubleshooting.

Labels
None yet
Projects
None yet

3 participants