Skip to content

Commit

Permalink
Mark PipelineInputs channels as optional if their wrapped channel is …
Browse files Browse the repository at this point in the history
…optional.

PiperOrigin-RevId: 623541282
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 10, 2024
1 parent 3064492 commit 84f1e4a
Show file tree
Hide file tree
Showing 5 changed files with 565 additions and 8 deletions.
4 changes: 4 additions & 0 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tfx.dsl.placeholder import artifact_placeholder
from tfx.dsl.placeholder import placeholder
from tfx.orchestration import data_types_utils
from tfx.orchestration import pipeline
from tfx.proto.orchestration import metadata_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import channel as channel_types
Expand Down Expand Up @@ -439,6 +440,9 @@ def compile_node_inputs(
for input_key, channel in tfx_node.inputs.items():
if compiler_utils.is_resolver(tfx_node):
min_count = 0
elif isinstance(tfx_node, pipeline.Pipeline):
pipeline_inputs_channel = tfx_node.inputs[input_key]
min_count = 0 if pipeline_inputs_channel.is_optional else 1
elif isinstance(tfx_node, base_component.BaseComponent):
spec_param = tfx_node.spec.INPUTS[input_key]
if (
Expand Down
12 changes: 8 additions & 4 deletions tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ nodes {
}
}
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down Expand Up @@ -1169,7 +1170,8 @@ nodes {
upstream_nodes: "data-ingestion-pipeline"
downstream_nodes: "Trainer"
execution_options {
caching_options {}
caching_options {
}
strategy: LAZILY_ALL_UPSTREAM_NODES_SUCCEEDED
max_execution_retries: 10
}
Expand Down Expand Up @@ -2206,7 +2208,8 @@ nodes {
downstream_nodes: "Pusher"
downstream_nodes: "infra-validator-pipeline"
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down Expand Up @@ -2507,7 +2510,8 @@ nodes {
upstream_nodes: "validate-and-push-pipeline_begin"
downstream_nodes: "InfraValidator"
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down
18 changes: 17 additions & 1 deletion tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,28 @@ def __init__(self):


def create_test_pipeline():
"""Creaters a pipeline with optional and allow_empty channels."""
upstream_component = UpstreamComponent()
my_component = MyComponent(
mandatory=upstream_component.outputs['first_model'],
optional_but_needed=upstream_component.outputs['second_model'],
optional_and_not_needed=upstream_component.outputs['third_model'])
p_in = pipeline.PipelineInputs({
'mandatory': upstream_component.outputs['first_model'],
'optional': upstream_component.outputs['second_model'].as_optional(),
})
subpipeline_component = MyComponent(
mandatory=p_in['mandatory'],
optional_but_needed=p_in['optional'],
)
subpipeline = pipeline.Pipeline(
pipeline_name='subpipeline',
pipeline_root=_pipeline_root,
components=[subpipeline_component],
inputs=p_in,
)
return pipeline.Pipeline(
pipeline_name=_pipeline_name,
pipeline_root=_pipeline_root,
components=[upstream_component, my_component])
components=[upstream_component, my_component, subpipeline],
)
Loading

0 comments on commit 84f1e4a

Please sign in to comment.