Skip to content

Commit

Permalink
Update with_* methods on BaseNode and BaseComponent to return Self
Browse files Browse the repository at this point in the history
Doing this perserves the type through `with_` calls, so that `MyComponent().with_id("foo")` is still a `MyComponent`, not a `BaseNode`

PiperOrigin-RevId: 617356462
  • Loading branch information
kmonte authored and tfx-copybara committed Mar 20, 2024
1 parent b859df4 commit 3151662
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 5 additions & 1 deletion tfx/dsl/components/base/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tfx.types.system_executions import SystemExecution
from tfx.utils import abc_utils
from tfx.utils import doc_controls
import typing_extensions

from google.protobuf import message

Expand Down Expand Up @@ -150,8 +151,11 @@ def _validate_spec(self, spec):
'got %s instead.') %
(self.__class__, self.__class__.SPEC_CLASS, spec))

# TODO(kmonte): Update this to Self once we're on 3.11 everywhere.
@doc_controls.do_not_doc_in_subclasses
def with_platform_config(self, config: message.Message) -> 'BaseComponent':
def with_platform_config(
self, config: message.Message
) -> typing_extensions.Self:
"""Attaches a proto-form platform config to a component.
The config will be a per-node platform-specific config.
Expand Down
10 changes: 6 additions & 4 deletions tfx/dsl/components/base/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tfx.utils import doc_controls
from tfx.utils import json_utils
from tfx.utils import name_utils
import typing_extensions


def _abstract_property() -> Any:
Expand Down Expand Up @@ -128,8 +129,9 @@ def component_id(self) -> str:
def id(self, id: str) -> None: # pylint: disable=redefined-builtin
self._id = id

# TODO(kmonte): Update this to Self once we're on 3.11 everywhere
@doc_controls.do_not_doc_in_subclasses
def with_id(self, id: str) -> 'BaseNode': # pylint: disable=redefined-builtin
def with_id(self, id: str) -> typing_extensions.Self: # pylint: disable=redefined-builtin
self._id = id
return self

Expand Down Expand Up @@ -166,10 +168,10 @@ def node_execution_options(
):
self._node_execution_options = copy.deepcopy(node_execution_options)

# TODO(kmonte): Update this to Self once we're on 3.11 everywhere
def with_node_execution_options(
self,
node_execution_options: utils.NodeExecutionOptions
) -> 'BaseNode':
self, node_execution_options: utils.NodeExecutionOptions
) -> typing_extensions.Self:
self.node_execution_options = node_execution_options
return self

Expand Down

0 comments on commit 3151662

Please sign in to comment.