RFC: Composable input/output pipeline #234
Codecov Report

```
@@           Coverage Diff           @@
##           master     #234   +/-   ##
=======================================
  Coverage   98.90%   98.90%
=======================================
  Files           6        6
  Lines         728      728
=======================================
  Hits          720      720
  Misses          8        8
```

Continue to review the full report at Codecov.
📝 Docs preview for commit 98d279b at: https://www.adriangb.com/scikeras/refs/pull/234/merge/
Thanks for this RFC – this is easier to review than a PR. Here are some questions:
- What issue are you trying to solve with this RFC? What's an illustration of that issue?
- Is that issue user-facing?
- How does this RFC solve that (user-facing?) issue?
> If the data comes in as a `tf.data.Dataset`, the first pipeline is skipped. If not, it is run and the output is converted to a `Dataset`.
> The second pipeline is then run in all cases.
>
> These pipelines will consist of chained transformers implementing a Scikit-Learn-like interface, but without being restricted to the exact Scikit-Learn API.
> but without being restricted to the exact Scikit-Learn API.
By "exact Scikit-Learn API," you mean "the exact names of Scikit-learn transformers," right? And the proposed "Scikit-learn-like interface" is public, correct?
If the answers to both of those questions are affirmative, 👎 to that interface. How can this RFC be reworked to conform with the "exact Scikit-Learn API"?
By "exact Scikit-Learn API" I mean roughly:

- having a `fit` method with the signature `fit(X, y=None) -> self`
- having a `transform` method with the signature `transform(X) -> X'`
- having an `inverse_transform` method with the signature `inverse_transform(X') -> X`

And no other methods (that are part of the API). Also, `inverse_transform(transform(X)) == X` is not a requirement of the API, but it certainly is the spirit of it (and how most sklearn transformers are implemented, when it makes sense to do so).
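For concreteness, here is a minimal sketch of a transformer that follows that API (my own illustration, not code from the RFC):

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class Log1pTransformer(BaseEstimator, TransformerMixin):
    """Toy transformer adhering to the 'exact Scikit-Learn API'."""

    def fit(self, X, y=None):
        # Nothing to learn here; real transformers would store fitted
        # state as attributes (e.g. self.mean_) and return self.
        return self

    def transform(self, X):
        return np.log1p(X)

    def inverse_transform(self, Xt):
        # Honors the "spirit": inverse_transform(transform(X)) == X
        return np.expm1(Xt)
```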
> And the proposed "Scikit-learn-like interface" is public, correct?
Yes, it would be public and consists of `ArrayTransformer` and `DatasetTransformer`. These can be base classes, protocols/interfaces, or just duck typing (an implementation detail).

The idea is that users could take our default data validations/transformations and mix/match them with their own for use cases like multi-output models or multi-output class reweighting.
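As a hypothetical illustration of that mix/match idea (the method names follow the `DatasetTransformer` interface quoted further down; the `dataset_pipeline` parameter and `DefaultValidation` name are my inventions):

```python
import tensorflow as tf


class ClassWeightTransformer:
    """Hypothetical user-supplied DatasetTransformer that adds a
    per-sample weight column to each (x, y) element."""

    def set_model(self, model) -> None:
        self.model = model

    def transform_input(self, data: tf.data.Dataset, *, initialize: bool = True) -> tf.data.Dataset:
        # Upweight the positive class by 2x. (transform_output omitted for brevity.)
        return data.map(lambda x, y: (x, y, tf.where(y == 1, 2.0, 1.0)))


# Hypothetical composition: SciKeras' default steps plus a custom one.
# est = BaseWrapper(model_fn, dataset_pipeline=[DefaultValidation(), ClassWeightTransformer()])
```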
This proposed interface violates the sklearn interface in several ways:

- by passing `X`, `y`, and `sample_weight` in `ArrayTransformer`
- by passing a `tf.data.Dataset` object in `DatasetTransformer`
- because these really aren't inverse transformations: the forward/input transformation transforms `X` & `y`, while the output transformation transforms `y'` (the y predictions) and/or `y` itself, so `inverse_transform(transform(X)) != X`
- by allowing side effects to be applied to the `model` (this is an addition to the API; it doesn't strictly break the implementation, but it does break the spirit)
> How can this RFC be reworked to conform with the "exact Scikit-Learn API"?
I don't think it can be easily reworked: there are too many limitations in the sklearn transformer API (and the authors have said as much themselves).

The good news is that the Scikit-Learn API is a strict subset of this API, so we could provide wrappers to convert the Scikit-Learn transformer API into this one (e.g. by specifying what `X` is to your transformer and dispatching the methods).
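Such a wrapper might look roughly like this (a sketch under assumptions: the `transform_input`/`transform_output` names come from the interface quoted below, their exact signatures are my guess, and `SklearnTargetAdapter` is a made-up name):

```python
class SklearnTargetAdapter:
    """Hypothetical adapter lifting an sklearn transformer into the RFC's
    ArrayTransformer-style interface, declaring that the sklearn
    transformer's "X" is SciKeras' target y."""

    def __init__(self, sk_transformer):
        self.sk_transformer = sk_transformer

    def set_model(self, model) -> None:
        self.model = model

    def transform_input(self, X, y, sample_weight):
        # Dispatch fit_transform onto y; X and sample_weight pass through.
        return X, self.sk_transformer.fit_transform(y), sample_weight

    def transform_output(self, y_pred_proba, y):
        # Dispatch inverse_transform onto the prediction side.
        return self.sk_transformer.inverse_transform(y_pred_proba), y
```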
But let me ask: why is it important to adhere to the Scikit-Learn API? I know there are good reasons; I just want to understand which ones you are thinking of.
My thoughts are that if we can't interoperate directly (i.e. use `make_pipeline` and other facilities of sklearn), there is not much value in being semi-compatible. I also think there aren't that many things (beyond pipelines) that would be useful here. But this may be short-sighted; I'm open to other opinions.
This library has one job: to bring the Scikit-learn API to Keras. That means there needs to be a really good reason to break the Scikit-learn API.
That reason needs to consider alternative implementations, including the one that exactly follows the Scikit-learn API. I imagine some questions that need to be answered include the following:
- What effect does the implementation have on developers?
- What effect does the implementation have on the user?
- What specific issues make the implementation inappropriate?
> ...
>
> ## Issues this can potentially resolve
These issues seem to be around modular validation/transforming. This RFC proposes one solution. What are other solutions? Why does this RFC represent the best solution?
> What are other solutions?

This is, of course, a great question. I think general solutions would be variations of the same idea, perhaps less structured (e.g. hardcoding that `tf.data.Dataset` inputs should skip `BaseWrapper._validate_data`). I'd have to think a bit to see if I can come up with any other structured approaches that might also solve the issues.
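For contrast, that less structured alternative would look roughly like this (a sketch; `_fit_keras_model` is a made-up helper name, and `_validate_data` internals are elided):

```python
import tensorflow as tf


class BaseWrapper:
    def fit(self, X, y=None, sample_weight=None):
        if isinstance(X, tf.data.Dataset):
            # Hardcoded bypass: Dataset inputs skip array validation.
            data = X
        else:
            X, y = self._validate_data(X, y)  # existing array checks
            data = tf.data.Dataset.from_tensor_slices((X, y))
        return self._fit_keras_model(data)
```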
Why not have two Scikit-learn transformers, one for validation and one to change the data?
We would need:

- Interface for validation using array-like data
- Interface for data changing using array-like data
- Interface for validation using `Dataset`
- Interface for data changing using `Dataset`

I think this is too many interfaces. The only difference between a "validation" and a "transformation" is that the transformation needs to return the data, while the validation does not. So by having the validation return the data, we can collapse those two concepts into one: validation transformers simply inspect the data but do not modify it.
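In other words, under the collapsed concept a validation step is just a transformer whose output equals its input, e.g. (an illustrative sketch; the `transform_input` signature is assumed from the discussion above):

```python
import numpy as np


class FiniteCheckValidator:
    """A 'validation' under the collapsed interface: inspect the data,
    raise on problems, and return it unchanged."""

    def transform_input(self, X, y, sample_weight):
        if not np.all(np.isfinite(X)):
            raise ValueError("X contains NaN or inf values")
        return X, y, sample_weight  # pass through untouched
```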
What use case/issue motivates having separate classes for Dataset/array-like inputs? Why not collapse it into one class?
Why is 4 interfaces "too many"? Why is the current framework with `target_encoder_` and `feature_encoder_` not sufficient?
I'm asking questions to get answers encoded in the RFC (an RFC is a really good place to encode these decisions).
Good to know it helps! And thank you for reviewing.
This solves some user-facing issues, as well as allowing us to rework things internally to be cleaner and more flexible for users. There are several different issues, but I think the clearest one is #160: the general user request is to be able to use `tf.data.Dataset` inputs. We could hack this together as is (by hardcoding bypasses to `BaseWrapper._validate_data`).

The other issues linked are generally around the theme of more validations (#106, #143) or figuring out how to clean up our current validations (#111, #209). These two themes are intertwined.
This RFC resolves this by providing a unified interface for these validations/transformations that is modular, composable and public. I envision it helping our users (since they can disable or add validation/transformation steps) as well as us (the developers) because we can organize our default validation/transformations better.
||
class ArrayTransformer: | ||
|
||
def set_model(self, model: "BaseWrapper") -> None: |
👎

Why not follow the Scikit-Learn API?

```python
class ArrayTransformer(BaseEstimator):
    def __init__(self, model=None):
        self.model = model
    ...

a = ArrayTransformer(model=model)  # option 0

a = ArrayTransformer()
a.model = model  # option 1

a = ArrayTransformer()
a.set_params(model=model)  # option 2
```

Yes, TF uses `set_model`. Why should their boilerplate be followed?
This would require users to give us a class to initialize (this object is being passed to the constructor that creates `model`).

I guess we could use parameter routing to allow setting any other parameters. Does this look any better?

```python
class MyTransf:
    def __init__(self, model, other):
        self.other = other
    ...

est = BaseWrapper(..., pipeline_param=[MyTransf], pipeline_param__0__other=True)
# or, relying on dict ordering to know what order to run the pipeline in:
est = BaseWrapper(..., pipeline_param={"tfname": MyTransf}, pipeline_param__tfname__other=True)
# or, same as an sklearn pipeline, using a tuple to set the name:
est = BaseWrapper(..., pipeline_param=[("tfname", MyTransf)], pipeline_param__tfname__other=True)
```
An alternative would be to require users to bind any other parameters using `functools.partial` or something.
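A sketch of that route (reusing the hypothetical `pipeline_param` from the routing example above; the user binds `other` up front and SciKeras supplies only `model` at construction time):

```python
from functools import partial

# MyTransf(model, other=True) is constructed later by SciKeras,
# which only needs to pass `model`.
est = BaseWrapper(..., pipeline_param=[partial(MyTransf, other=True)])
```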
> ```python
>     def transform_input(self, data: tf.data.Dataset, *, initialize: bool = True) -> tf.data.Dataset:
>         return data
>
>     def transform_output(self, y_pred_proba: np.ndarray, y: Union[np.ndarray, None]) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
> ```
Why not rename these functions `inverse_transform`?
Because their signatures and usage do not match `inverse_transform`'s signature. They wouldn't work in an sklearn `Pipeline`, nor would they be chainable with sklearn estimators.
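The mismatch is visible from the call shapes alone (sklearn's contract vs. the method quoted above; variable names are illustrative):

```python
# sklearn: one array in, one array out, chainable in a Pipeline
X = transformer.inverse_transform(Xt)

# this RFC: two arrays in, a tuple out, so not Pipeline-chainable
y_pred, y = transformer.transform_output(y_pred_proba, y)
```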
This is an attempt to collect ideas from various issues / PRs into a coherent framework and generalize our current input/output transformers.
@stsievert would you mind taking a look?