diff --git a/tfx/orchestration/portable/launcher.py b/tfx/orchestration/portable/launcher.py index e6de68315e..b17223c19e 100644 --- a/tfx/orchestration/portable/launcher.py +++ b/tfx/orchestration/portable/launcher.py @@ -149,10 +149,12 @@ def __init__( executor_spec: Optional[message.Message] = None, custom_driver_spec: Optional[message.Message] = None, platform_config: Optional[message.Message] = None, - custom_executor_operators: Optional[Dict[Any, - Type[ExecutorOperator]]] = None, - custom_driver_operators: Optional[Dict[Any, - Type[DriverOperator]]] = None): + custom_executor_operators: Optional[ + Dict[Any, Type[ExecutorOperator]] + ] = None, + custom_driver_operators: Optional[Dict[Any, Type[DriverOperator]]] = None, + creds: Optional[grpc.ChannelCredentials] = None, + ): """Initializes a Launcher. Args: @@ -174,6 +176,7 @@ def __init__( ExecutorOperation implementation. custom_driver_operators: a map of ExecutableSpec to its DriverOperator implementation. + creds: The credentials to use for the execution watcher. Raises: ValueError: when component and component_config are not launchable by the @@ -191,6 +194,7 @@ def __init__( self._driver_operators = {} self._driver_operators.update(DEFAULT_DRIVER_OPERATORS) self._driver_operators.update(custom_driver_operators or {}) + self._creds = creds or grpc.local_server_credentials() self._executor_operator = None if executor_spec: @@ -600,7 +604,8 @@ def launch(self) -> Optional[data_types.ExecutionInfo]: port=portpicker.pick_unused_port(), mlmd_connection=self._mlmd_connection, execution=execution_preparation_result.execution_metadata, - creds=grpc.local_server_credentials()) + creds=self._creds, + ) self._executor_operator.with_execution_watcher( executor_watcher.address) executor_watcher.start()