diff --git a/tfx/dsl/compiler/node_inputs_compiler.py b/tfx/dsl/compiler/node_inputs_compiler.py index 379e4fe058..6a1d2bf4ce 100644 --- a/tfx/dsl/compiler/node_inputs_compiler.py +++ b/tfx/dsl/compiler/node_inputs_compiler.py @@ -26,6 +26,7 @@ from tfx.dsl.placeholder import artifact_placeholder from tfx.dsl.placeholder import placeholder from tfx.orchestration import data_types_utils +from tfx.orchestration import pipeline from tfx.proto.orchestration import metadata_pb2 from tfx.proto.orchestration import pipeline_pb2 from tfx.types import channel as channel_types @@ -439,6 +440,9 @@ def compile_node_inputs( for input_key, channel in tfx_node.inputs.items(): if compiler_utils.is_resolver(tfx_node): min_count = 0 + elif isinstance(tfx_node, pipeline.Pipeline): + pipeline_inputs_channel = tfx_node.inputs[input_key] + min_count = 0 if pipeline_inputs_channel.is_optional else 1 elif isinstance(tfx_node, base_component.BaseComponent): spec_param = tfx_node.spec.INPUTS[input_key] if ( diff --git a/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt index e6e4ca61d9..b257611d5c 100644 --- a/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt @@ -82,7 +82,8 @@ nodes { } } execution_options { - caching_options {} + caching_options { + } } } } @@ -1169,7 +1170,8 @@ nodes { upstream_nodes: "data-ingestion-pipeline" downstream_nodes: "Trainer" execution_options { - caching_options {} + caching_options { + } strategy: LAZILY_ALL_UPSTREAM_NODES_SUCCEEDED max_execution_retries: 10 } @@ -2206,7 +2208,8 @@ nodes { downstream_nodes: "Pusher" downstream_nodes: "infra-validator-pipeline" execution_options { - caching_options {} + caching_options { + } } } } @@ -2507,7 +2510,8 @@ nodes { upstream_nodes: "validate-and-push-pipeline_begin" downstream_nodes: "InfraValidator" execution_options { - caching_options {} + 
caching_options { + } } } } diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py index 83377aa062..e9b51b46a4 100644 --- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py +++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py @@ -106,12 +106,28 @@ def __init__(self): def create_test_pipeline(): + """Creates a pipeline with optional and allow_empty channels.""" upstream_component = UpstreamComponent() my_component = MyComponent( mandatory=upstream_component.outputs['first_model'], optional_but_needed=upstream_component.outputs['second_model'], optional_and_not_needed=upstream_component.outputs['third_model']) + p_in = pipeline.PipelineInputs({ + 'mandatory': upstream_component.outputs['first_model'], + 'optional': upstream_component.outputs['second_model'].as_optional(), + }) + subpipeline_component = MyComponent( + mandatory=p_in['mandatory'], + optional_but_needed=p_in['optional'], + ) + subpipeline = pipeline.Pipeline( + pipeline_name='subpipeline', + pipeline_root=_pipeline_root, + components=[subpipeline_component], + inputs=p_in, + ) return pipeline.Pipeline( pipeline_name=_pipeline_name, pipeline_root=_pipeline_root, - components=[upstream_component, my_component]) + components=[upstream_component, my_component, subpipeline], + ) diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt index bac3100364..0355afd2f5 100644 --- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt @@ -84,6 +84,7 @@ nodes { } } downstream_nodes: "MyComponent" + downstream_nodes: "subpipeline" execution_options { caching_options { } @@ -285,6 +286,538 @@ nodes { } } } +nodes { + sub_pipeline { + pipeline_info { + id: "subpipeline" + } + nodes { 
+ pipeline_node { + node_info { + type { + name: "tfx.orchestration.pipeline.Pipeline_begin" + } + id: "subpipeline_begin" + } + contexts { + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "subpipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + contexts { + type { + name: "node" + } + name { + field_value { + string_value: "subpipeline.subpipeline_begin" + } + } + } + } + inputs { + inputs { + key: "mandatory" + value { + channels { + producer_node_query { + id: "UpstreamComponent" + } + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline.UpstreamComponent" + } + } + } + artifact_query { + type { + name: "Model" + base_type: MODEL + } + } + output_key: "first_model" + } + min_count: 1 + } + } + inputs { + key: "optional" + value { + channels { + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline.UpstreamComponent" + } + } + } + artifact_query { + type { + name: "Model" + base_type: MODEL + } + } + output_key: 
"second_model" + } + } + } + } + outputs { + outputs { + key: "mandatory" + value { + artifact_spec { + type { + name: "Model" + base_type: MODEL + } + } + } + } + outputs { + key: "optional" + value { + artifact_spec { + type { + name: "Model" + base_type: MODEL + } + } + } + } + } + upstream_nodes: "UpstreamComponent" + downstream_nodes: "MyComponent" + execution_options { + caching_options { + } + } + } + } + nodes { + pipeline_node { + node_info { + type { + name: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.MyComponent" + } + id: "MyComponent" + } + contexts { + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "subpipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + contexts { + type { + name: "node" + } + name { + field_value { + string_value: "subpipeline.MyComponent" + } + } + } + } + inputs { + inputs { + key: "mandatory" + value { + channels { + producer_node_query { + id: "subpipeline_begin" + } + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "subpipeline" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + 
name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "subpipeline.subpipeline_begin" + } + } + } + artifact_query { + type { + name: "Model" + base_type: MODEL + } + } + output_key: "mandatory" + } + min_count: 1 + } + } + inputs { + key: "optional_but_needed" + value { + channels { + producer_node_query { + id: "subpipeline_begin" + } + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "subpipeline" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "optional_and_allow_empty_pipeline" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "subpipeline.subpipeline_begin" + } + } + } + artifact_query { + type { + name: "Model" + base_type: MODEL + } + } + output_key: "optional" + } + } + } + } + upstream_nodes: "subpipeline_begin" + execution_options { + caching_options { + } + } + } + } + nodes { + pipeline_node { + node_info { + type { + name: "tfx.orchestration.pipeline.Pipeline_end" + } + id: "subpipeline_end" + } + contexts { + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "subpipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + contexts { + type { + name: "pipeline" + } + name { + field_value { + 
string_value: "optional_and_allow_empty_pipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + contexts { + type { + name: "node" + } + name { + field_value { + string_value: "subpipeline.subpipeline_end" + } + } + } + } + } + } + runtime_spec { + pipeline_root { + runtime_parameter { + name: "pipeline-root" + type: STRING + default_value { + string_value: "pipeline/optional_and_allow_empty_pipeline" + } + } + } + pipeline_run_id { + structural_runtime_parameter { + parts { + constant_value: "subpipeline_" + } + parts { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + } + } + execution_mode: SYNC + deployment_config { + [type.googleapis.com/tfx.orchestration.IntermediateDeploymentConfig] { + executor_specs { + key: "MyComponent" + value { + [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] { + class_path: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.Executor" + } + } + } + } + } + } +} runtime_spec { pipeline_root { runtime_parameter { diff --git a/tfx/types/channel.py b/tfx/types/channel.py index 91e6abbfe3..f6b3fe6346 100644 --- a/tfx/types/channel.py +++ b/tfx/types/channel.py @@ -120,7 +120,7 @@ class BaseChannel(abc.ABC, Generic[_AT]): set. 
""" - def __init__(self, type: Type[_AT]): # pylint: disable=redefined-builtin + def __init__(self, type: Type[_AT], is_optional: Optional[bool] = None): # pylint: disable=redefined-builtin if not _is_artifact_type(type): raise ValueError( 'Argument "type" of BaseChannel constructor must be a subclass of ' @@ -128,7 +128,7 @@ def __init__(self, type: Type[_AT]): # pylint: disable=redefined-builtin self._artifact_type = type self._input_trigger = None self._original_channel = None - self._is_optional = None + self._is_optional = is_optional @property def is_optional(self) -> Optional[bool]: @@ -663,7 +663,7 @@ class PipelineInputChannel(BaseChannel): """ def __init__(self, wrapped: BaseChannel, output_key: str): - super().__init__(type=wrapped.type) + super().__init__(type=wrapped.type, is_optional=wrapped.is_optional) self._wrapped = wrapped self._output_key = output_key self._pipeline = None