From 0313e7c5a0a046b3a6240ef99b4f8b97b3acc770 Mon Sep 17 00:00:00 2001
From: kmonte
Date: Wed, 17 Apr 2024 10:47:58 -0700
Subject: [PATCH] Fix as_optional()

Because `Channel` uses `id` as its hash and `as_optional()` creates a new
object, this check [1] does not pass for the copied channel; we instead fall
through to [2], which does not add the pipeline run context.

PiperOrigin-RevId: 625736342
---
 tfx/dsl/compiler/compiler.py                  |  87 ++------
 tfx/dsl/compiler/compiler_context.py          |   2 +
 tfx/dsl/compiler/node_contexts_compiler.py    | 108 +++++++++
 .../compiler/node_contexts_compiler_test.py   | 157 ++++++++++++++
 tfx/dsl/compiler/node_inputs_compiler.py      |  90 ++++++--
 tfx/dsl/compiler/node_inputs_compiler_test.py |   3 +-
 .../optional_and_allow_empty_pipeline.py      |  16 +-
 ...and_allow_empty_pipeline_input_v2_ir.pbtxt | 205 ++++++++++++++++++
 8 files changed, 584 insertions(+), 84 deletions(-)
 create mode 100644 tfx/dsl/compiler/node_contexts_compiler.py
 create mode 100644 tfx/dsl/compiler/node_contexts_compiler_test.py

diff --git a/tfx/dsl/compiler/compiler.py b/tfx/dsl/compiler/compiler.py
index 4af95be5af..7e4bf0c97a 100644
--- a/tfx/dsl/compiler/compiler.py
+++ b/tfx/dsl/compiler/compiler.py
@@ -19,6 +19,7 @@
 from tfx.dsl.compiler import compiler_context
 from tfx.dsl.compiler import compiler_utils
 from tfx.dsl.compiler import constants
+from tfx.dsl.compiler import node_contexts_compiler
 from tfx.dsl.compiler import node_execution_options_utils
 from tfx.dsl.compiler import node_inputs_compiler
 from tfx.dsl.components.base import base_component
@@ -56,7 +57,12 @@ def _compile_pipeline_begin_node(
 
   # Step 2: Node Context
   # Inner pipeline's contexts.
-  _set_node_context(node, pipeline_ctx)
+  node.contexts.CopyFrom(
+      node_contexts_compiler.compile_node_contexts(
+          pipeline_ctx,
+          node.node_info.id,
+      )
+  )
 
   # Step 3: Node inputs
   # Pipeline node inputs are stored as the inputs of the PipelineBegin node.
@@ -121,7 +127,12 @@ def _compile_pipeline_end_node(
 
   # Step 2: Node Context
   # Inner pipeline's contexts.
-  _set_node_context(node, pipeline_ctx)
+  node.contexts.CopyFrom(
+      node_contexts_compiler.compile_node_contexts(
+          pipeline_ctx,
+          node.node_info.id,
+      )
+  )
 
   # Step 3: Node inputs
   node_inputs_compiler.compile_node_inputs(
@@ -194,7 +205,12 @@ def _compile_node(
   node.node_info.id = tfx_node.id
 
   # Step 2: Node Context
-  _set_node_context(node, pipeline_ctx)
+  node.contexts.CopyFrom(
+      node_contexts_compiler.compile_node_contexts(
+          pipeline_ctx,
+          node.node_info.id,
+      )
+  )
 
   # Step 3: Node inputs
   node_inputs_compiler.compile_node_inputs(
@@ -386,71 +402,6 @@ def _validate_pipeline(tfx_pipeline: pipeline.Pipeline,
         raise ValueError("Subpipeline has to be Sync execution mode.")
 
 
-def _set_node_context(node: pipeline_pb2.PipelineNode,
-                      pipeline_ctx: compiler_context.PipelineContext):
-  """Compiles the node contexts of a pipeline node."""
-  # Context for the pipeline, across pipeline runs.
-  pipeline_context_pb = node.contexts.contexts.add()
-  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
-  pipeline_context_pb.name.field_value.string_value = (
-      pipeline_ctx.pipeline_info.pipeline_context_name)
-
-  # Context for the current pipeline run.
-  if pipeline_ctx.is_sync_mode:
-    pipeline_run_context_pb = node.contexts.contexts.add()
-    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
-    # TODO(kennethyang): Miragte pipeline run id to structural_runtime_parameter
-    # To keep existing IR textprotos used in tests unchanged, we only use
-    # structural_runtime_parameter for subpipelines. After the subpipeline being
-    # implemented, we will need to migrate normal pipelines to
-    # structural_runtime_parameter as well for consistency. Similar for below.
-    if pipeline_ctx.is_subpipeline:
-      compiler_utils.set_structural_runtime_parameter_pb(
-          pipeline_run_context_pb.name.structural_runtime_parameter, [
-              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
-              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
-          ])
-    else:
-      compiler_utils.set_runtime_parameter_pb(
-          pipeline_run_context_pb.name.runtime_parameter,
-          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
-
-  # Contexts inherited from the parent pipelines.
-  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
-    parent_pipeline_context_pb = node.contexts.contexts.add()
-    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
-    parent_pipeline_context_pb.name.field_value.string_value = (
-        parent_pipeline.pipeline_info.pipeline_context_name)
-
-    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
-      pipeline_run_context_pb = node.contexts.contexts.add()
-      pipeline_run_context_pb.type.name = (
-          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)
-
-      # TODO(kennethyang): Miragte pipeline run id to structural runtime
-      # parameter for the similar reason mentioned above.
-      # Use structural runtime parameter to represent pipeline_run_id except
-      # for the root level pipeline, for backward compatibility.
-      if i == len(pipeline_ctx.parent_pipelines) - 1:
-        compiler_utils.set_runtime_parameter_pb(
-            pipeline_run_context_pb.name.runtime_parameter,
-            constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
-      else:
-        compiler_utils.set_structural_runtime_parameter_pb(
-            pipeline_run_context_pb.name.structural_runtime_parameter, [
-                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
-                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
-            ])
-
-  # Context for the node, across pipeline runs.
-  node_context_pb = node.contexts.contexts.add()
-  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
-  node_context_pb.name.field_value.string_value = (
-      compiler_utils.node_context_name(
-          pipeline_ctx.pipeline_info.pipeline_context_name,
-          node.node_info.id))
-
-
 def _set_node_outputs(node: pipeline_pb2.PipelineNode,
                       tfx_node_outputs: Dict[str, types.Channel]):
   """Compiles the outputs of a pipeline node."""
diff --git a/tfx/dsl/compiler/compiler_context.py b/tfx/dsl/compiler/compiler_context.py
index 17193cb4f2..8549d79c2e 100644
--- a/tfx/dsl/compiler/compiler_context.py
+++ b/tfx/dsl/compiler/compiler_context.py
@@ -55,6 +55,8 @@ def __init__(self,
     # Mapping from Channel object to compiled Channel proto.
     self.channels = dict()
 
+    self.node_context_protos_cache: dict[str, pipeline_pb2.NodeContexts] = {}
+
     # Node ID -> NodeContext
     self._node_contexts: Dict[str, NodeContext] = {}
 
diff --git a/tfx/dsl/compiler/node_contexts_compiler.py b/tfx/dsl/compiler/node_contexts_compiler.py
new file mode 100644
index 0000000000..73e73ea032
--- /dev/null
+++ b/tfx/dsl/compiler/node_contexts_compiler.py
@@ -0,0 +1,108 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Compiles NodeContexts."""
+
+from tfx.dsl.compiler import compiler_context
+from tfx.dsl.compiler import compiler_utils
+from tfx.dsl.compiler import constants
+from tfx.orchestration import pipeline
+from tfx.proto.orchestration import pipeline_pb2
+
+
+def compile_node_contexts(
+    pipeline_ctx: compiler_context.PipelineContext,
+    node_id: str,
+) -> pipeline_pb2.NodeContexts:
+  """Compiles the node contexts of a pipeline node."""
+
+  if pipeline_ctx.pipeline_info is None:
+    return pipeline_pb2.NodeContexts()
+  if maybe_contexts := pipeline_ctx.node_context_protos_cache.get(node_id):
+    return maybe_contexts
+
+  node_contexts = pipeline_pb2.NodeContexts()
+  # Context for the pipeline, across pipeline runs.
+  pipeline_context_pb = node_contexts.contexts.add()
+  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
+  pipeline_context_pb.name.field_value.string_value = (
+      pipeline_ctx.pipeline_info.pipeline_context_name
+  )
+
+  # Context for the current pipeline run.
+  if pipeline_ctx.is_sync_mode:
+    pipeline_run_context_pb = node_contexts.contexts.add()
+    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
+    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
+    # To keep existing IR textprotos used in tests unchanged, we only use
+    # structural_runtime_parameter for subpipelines. After subpipelines are
+    # fully implemented, we will need to migrate normal pipelines to
+    # structural_runtime_parameter as well for consistency. Similar for below.
+    if pipeline_ctx.is_subpipeline:
+      compiler_utils.set_structural_runtime_parameter_pb(
+          pipeline_run_context_pb.name.structural_runtime_parameter,
+          [
+              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
+              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
+          ],
+      )
+    else:
+      compiler_utils.set_runtime_parameter_pb(
+          pipeline_run_context_pb.name.runtime_parameter,
+          constants.PIPELINE_RUN_ID_PARAMETER_NAME,
+          str,
+      )
+
+  # Contexts inherited from the parent pipelines.
+  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
+    parent_pipeline_context_pb = node_contexts.contexts.add()
+    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
+    parent_pipeline_context_pb.name.field_value.string_value = (
+        parent_pipeline.pipeline_info.pipeline_context_name
+    )
+
+    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
+      pipeline_run_context_pb = node_contexts.contexts.add()
+      pipeline_run_context_pb.type.name = (
+          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
+      )
+
+      # TODO(kennethyang): Migrate pipeline run id to structural runtime
+      # parameter for the same reason mentioned above.
+      # Use structural runtime parameter to represent pipeline_run_id except
+      # for the root level pipeline, for backward compatibility.
+      if i == len(pipeline_ctx.parent_pipelines) - 1:
+        compiler_utils.set_runtime_parameter_pb(
+            pipeline_run_context_pb.name.runtime_parameter,
+            constants.PIPELINE_RUN_ID_PARAMETER_NAME,
+            str,
+        )
+      else:
+        compiler_utils.set_structural_runtime_parameter_pb(
+            pipeline_run_context_pb.name.structural_runtime_parameter,
+            [
+                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
+                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
+            ],
+        )
+
+  # Context for the node, across pipeline runs.
+  node_context_pb = node_contexts.contexts.add()
+  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
+  node_context_pb.name.field_value.string_value = (
+      compiler_utils.node_context_name(
+          pipeline_ctx.pipeline_info.pipeline_context_name, node_id
+      )
+  )
+  pipeline_ctx.node_context_protos_cache[node_id] = node_contexts
+  return node_contexts
diff --git a/tfx/dsl/compiler/node_contexts_compiler_test.py b/tfx/dsl/compiler/node_contexts_compiler_test.py
new file mode 100644
index 0000000000..c30d9d50df
--- /dev/null
+++ b/tfx/dsl/compiler/node_contexts_compiler_test.py
@@ -0,0 +1,157 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for tfx.dsl.compiler.node_contexts_compiler."""
+
+import tensorflow as tf
+from tfx.dsl.compiler import compiler_context
+from tfx.dsl.compiler import node_contexts_compiler
+from tfx.orchestration import pipeline
+from tfx.proto.orchestration import pipeline_pb2
+
+from google.protobuf import text_format
+
+_NODE_ID = 'test_node'
+_PIPELINE_NAME = 'test_pipeline'
+
+
+class NodeContextsCompilerTest(tf.test.TestCase):
+
+  def test_compile_node_contexts(self):
+    expected_node_contexts = text_format.Parse(
+        """
+        contexts {
+          type {
+            name: "pipeline"
+          }
+          name {
+            field_value {
+              string_value: "test_pipeline"
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "pipeline_run"
+          }
+          name {
+            runtime_parameter {
+              name: "pipeline-run-id"
+              type: STRING
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "node"
+          }
+          name {
+            field_value {
+              string_value: "test_pipeline.test_node"
+            }
+          }
+        }
+        """,
+        pipeline_pb2.NodeContexts(),
+    )
+    self.assertProtoEquals(
+        node_contexts_compiler.compile_node_contexts(
+            compiler_context.PipelineContext(pipeline.Pipeline(_PIPELINE_NAME)),
+            _NODE_ID,
+        ),
+        expected_node_contexts,
+    )
+
+  def test_compile_node_contexts_for_subpipeline(self):
+    parent_context = compiler_context.PipelineContext(
+        pipeline.Pipeline(_PIPELINE_NAME)
+    )
+    subpipeline_context = compiler_context.PipelineContext(
+        pipeline.Pipeline('subpipeline'), parent_context
+    )
+
+    expected_node_contexts = text_format.Parse(
+        """
+        contexts {
+          type {
+            name: "pipeline"
+          }
+          name {
+            field_value {
+              string_value: "subpipeline"
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "pipeline_run"
+          }
+          name {
+            structural_runtime_parameter {
+              parts {
+                constant_value: "subpipeline_"
+              }
+              parts {
+                runtime_parameter {
+                  name: "pipeline-run-id"
+                  type: STRING
+                }
+              }
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "pipeline"
+          }
+          name {
+            field_value {
+              string_value: "test_pipeline"
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "pipeline_run"
+          }
+          name {
+            runtime_parameter {
+              name: "pipeline-run-id"
+              type: STRING
+            }
+          }
+        }
+        contexts {
+          type {
+            name: "node"
+          }
+          name {
+            field_value {
+              string_value: "subpipeline.test_node"
+            }
+          }
+        }
+        """,
+        pipeline_pb2.NodeContexts(),
+    )
+    self.assertProtoEquals(
+        node_contexts_compiler.compile_node_contexts(
+            subpipeline_context,
+            _NODE_ID,
+        ),
+        expected_node_contexts,
+    )
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tfx/dsl/compiler/node_inputs_compiler.py b/tfx/dsl/compiler/node_inputs_compiler.py
index 6a1d2bf4ce..bd6423ecae 100644
--- a/tfx/dsl/compiler/node_inputs_compiler.py
+++ b/tfx/dsl/compiler/node_inputs_compiler.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 """Compiler submodule specialized for NodeInputs."""
 
+from collections.abc import Iterable
 from typing import Type, cast
 
 from tfx import types
 from tfx.dsl.compiler import compiler_context
 from tfx.dsl.compiler import compiler_utils
 from tfx.dsl.compiler import constants
+from tfx.dsl.compiler import node_contexts_compiler
 from tfx.dsl.components.base import base_component
 from tfx.dsl.components.base import base_node
 from tfx.dsl.experimental.conditionals import conditional
@@ -37,6 +39,17 @@
 from tfx.utils import name_utils
 from tfx.utils import typing_utils
 
+from ml_metadata.proto import metadata_store_pb2
+
+
+def _get_tfx_value(value: str) -> pipeline_pb2.Value:
+  """Returns a TFX Value containing the provided string."""
+  return pipeline_pb2.Value(
+      field_value=data_types_utils.set_metadata_value(
+          metadata_store_pb2.Value(), value
+      )
+  )
+
 
 def _compile_input_graph(
     pipeline_ctx: compiler_context.PipelineContext,
@@ -121,6 +134,17 @@ def compile_op_node(op_node: resolver_op.OpNode):
   return input_graph_id
 
 
+def _compile_channel_pb_contexts(
+    context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]],
+    result: pipeline_pb2.InputSpec.Channel,
+):
+  """Adds contexts to the channel."""
+  for context_type, context_value in context_types_and_names:
+    ctx = result.context_queries.add()
+    ctx.type.name = context_type
+    ctx.name.CopyFrom(context_value)
+
+
 def _compile_channel_pb(
     artifact_type: Type[types.Artifact],
     pipeline_name: str,
@@ -133,15 +157,19 @@ def _compile_channel_pb(
   result.artifact_query.type.CopyFrom(mlmd_artifact_type)
   result.artifact_query.type.ClearField('properties')
 
-  ctx = result.context_queries.add()
-  ctx.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
-  ctx.name.field_value.string_value = pipeline_name
-
+  contexts_types_and_values = [
+      (constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value(pipeline_name))
+  ]
   if node_id:
-    ctx = result.context_queries.add()
-    ctx.type.name = constants.NODE_CONTEXT_TYPE_NAME
-    ctx.name.field_value.string_value = compiler_utils.node_context_name(
-        pipeline_name, node_id)
+    contexts_types_and_values.append(
+        (
+            constants.NODE_CONTEXT_TYPE_NAME,
+            _get_tfx_value(
+                compiler_utils.node_context_name(pipeline_name, node_id)
+            ),
+        ),
+    )
+  _compile_channel_pb_contexts(contexts_types_and_values, result)
 
   if output_key:
     result.output_key = output_key
@@ -198,7 +226,8 @@ def _compile_input_spec(
           pipeline_name=channel.pipeline.id,
           node_id=channel.wrapped.producer_component_id,
           output_key=channel.output_key,
-          result=result.inputs[input_key].channels.add())
+          result=result.inputs[input_key].channels.add(),
+      )
 
   elif isinstance(channel, channel_types.ExternalPipelineChannel):
     channel = cast(channel_types.ExternalPipelineChannel, channel)
@@ -208,12 +237,17 @@ def _compile_input_spec(
         pipeline_name=channel.pipeline_name,
         node_id=channel.producer_component_id,
         output_key=channel.output_key,
-        result=result_input_channel)
+        result=result_input_channel,
+    )
 
     if channel.pipeline_run_id:
-      ctx = result_input_channel.context_queries.add()
-      ctx.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
-      ctx.name.field_value.string_value = channel.pipeline_run_id
+      _compile_channel_pb_contexts(
+          [(
+              constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
+              _get_tfx_value(channel.pipeline_run_id),
+          )],
+          result_input_channel,
+      )
 
   if pipeline_ctx.pipeline.platform_config:
     project_config = (
@@ -235,6 +269,33 @@ def _compile_input_spec(
       )
       result_input_channel.metadata_connection_config.Pack(config)
 
+  # Note that this path is *usually* not taken, as most output channels
+  # already exist in pipeline_ctx.channels; they are added there after
+  # compiler._generate_input_spec_for_outputs is called.
+  # This path is taken when a channel has been copied, for example by
+  # `as_optional()`, because Channel hashes on `id`.
+  elif isinstance(channel, channel_types.OutputChannel):
+    channel = cast(channel_types.Channel, channel)
+    result_input_channel = result.inputs[input_key].channels.add()
+    _compile_channel_pb(
+        artifact_type=channel.type,
+        pipeline_name=pipeline_ctx.pipeline_info.pipeline_name,
+        node_id=channel.producer_component_id,
+        output_key=channel.output_key,
+        result=result_input_channel,
+    )
+    node_contexts = node_contexts_compiler.compile_node_contexts(
+        pipeline_ctx, tfx_node.id
+    )
+    contexts_to_add = []
+    for context_spec in node_contexts.contexts:
+      if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME:
+        contexts_to_add.append((
+            constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
+            context_spec.name,
+        ))
+    _compile_channel_pb_contexts(contexts_to_add, result_input_channel)
+
   elif isinstance(channel, channel_types.Channel):
     channel = cast(channel_types.Channel, channel)
     _compile_channel_pb(
@@ -242,7 +303,8 @@ def _compile_input_spec(
         pipeline_name=pipeline_ctx.pipeline_info.pipeline_name,
         node_id=channel.producer_component_id,
         output_key=channel.output_key,
-        result=result.inputs[input_key].channels.add())
+        result=result.inputs[input_key].channels.add(),
+    )
 
   elif isinstance(channel, channel_types.UnionChannel):
     channel = cast(channel_types.UnionChannel, channel)
diff --git a/tfx/dsl/compiler/node_inputs_compiler_test.py b/tfx/dsl/compiler/node_inputs_compiler_test.py
index d2b3301cd3..5bb2844e4f 100644
--- a/tfx/dsl/compiler/node_inputs_compiler_test.py
+++ b/tfx/dsl/compiler/node_inputs_compiler_test.py
@@ -145,7 +145,8 @@ def _get_channel_pb(
         pipeline_name=pipeline_name or self.pipeline_name,
         node_id=node_id,
         output_key=output_key,
-        result=result)
+        result=result,
+    )
     return result
 
   def testCompileAlreadyCompiledInputs(self):
diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py
index e9b51b46a4..43ef1ce814 100644
--- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py
+++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py
@@ -112,6 +112,15 @@ def create_test_pipeline():
       mandatory=upstream_component.outputs['first_model'],
       optional_but_needed=upstream_component.outputs['second_model'],
       optional_and_not_needed=upstream_component.outputs['third_model'])
+  as_optional_component = MyComponent(
+      mandatory=upstream_component.outputs['second_model'].as_optional(),
+      optional_but_needed=upstream_component.outputs[
+          'second_model'
+      ].as_optional(),
+      optional_and_not_needed=upstream_component.outputs[
+          'third_model'
+      ].as_optional(),
+  ).with_id('as_optional_component')
   p_in = pipeline.PipelineInputs({
       'mandatory': upstream_component.outputs['first_model'],
       'optional': upstream_component.outputs['second_model'].as_optional(),
@@ -129,5 +138,10 @@ def create_test_pipeline():
   return pipeline.Pipeline(
       pipeline_name=_pipeline_name,
       pipeline_root=_pipeline_root,
-      components=[upstream_component, my_component, subpipeline],
+      components=[
+          upstream_component,
+          my_component,
+          as_optional_component,
+          subpipeline,
+      ],
   )
diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt
index 0355afd2f5..2cff1ca2f3 100644
--- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt
+++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt
@@ -84,6 +84,7 @@ nodes {
       }
    }
  }
  downstream_nodes: "MyComponent"
+  downstream_nodes: "as_optional_component"
  downstream_nodes: "subpipeline"
  execution_options {
    caching_options {
@@ -286,6 +287,191 @@ nodes {
    }
  }
}
+nodes {
+  pipeline_node {
+    node_info {
+      type {
+        name: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.MyComponent"
+      }
+      id: "as_optional_component"
+    }
+    contexts {
+      contexts {
+        type {
+          name: "pipeline"
+        }
+        name {
+          field_value {
+            string_value: "optional_and_allow_empty_pipeline"
+          }
+        }
+      }
+      contexts {
+        type {
+          name: "pipeline_run"
+        }
+        name {
+          runtime_parameter {
+            name: "pipeline-run-id"
+            type: STRING
+          }
+        }
+      }
+      contexts {
+        type {
+          name: "node"
+        }
+        name {
+          field_value {
+            string_value: "optional_and_allow_empty_pipeline.as_optional_component"
+          }
+        }
+      }
+    }
+    inputs {
+      inputs {
+        key: "mandatory"
+        value {
+          channels {
+            context_queries {
+              type {
+                name: "pipeline"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "node"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline.UpstreamComponent"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "pipeline_run"
+              }
+              name {
+                runtime_parameter {
+                  name: "pipeline-run-id"
+                  type: STRING
+                }
+              }
+            }
+            artifact_query {
+              type {
+                name: "Model"
+                base_type: MODEL
+              }
+            }
+            output_key: "second_model"
+          }
+        }
+      }
+      inputs {
+        key: "optional_and_not_needed"
+        value {
+          channels {
+            context_queries {
+              type {
+                name: "pipeline"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "node"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline.UpstreamComponent"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "pipeline_run"
+              }
+              name {
+                runtime_parameter {
+                  name: "pipeline-run-id"
+                  type: STRING
+                }
+              }
+            }
+            artifact_query {
+              type {
+                name: "Model"
+                base_type: MODEL
+              }
+            }
+            output_key: "third_model"
+          }
+        }
+      }
+      inputs {
+        key: "optional_but_needed"
+        value {
+          channels {
+            context_queries {
+              type {
+                name: "pipeline"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "node"
+              }
+              name {
+                field_value {
+                  string_value: "optional_and_allow_empty_pipeline.UpstreamComponent"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "pipeline_run"
+              }
+              name {
+                runtime_parameter {
+                  name: "pipeline-run-id"
+                  type: STRING
+                }
+              }
+            }
+            artifact_query {
+              type {
+                name: "Model"
+                base_type: MODEL
+              }
+            }
+            output_key: "second_model"
+          }
+        }
+      }
+    }
+    upstream_nodes: "UpstreamComponent"
+    execution_options {
+      caching_options {
+      }
+    }
+  }
+}
 nodes {
   sub_pipeline {
     pipeline_info {
@@ -434,6 +620,17 @@
          }
        }
      }
+      context_queries {
+        type {
+          name: "pipeline_run"
+        }
+        name {
+          runtime_parameter {
+            name: "pipeline-run-id"
+            type: STRING
+          }
+        }
+      }
      artifact_query {
        type {
          name: "Model"
@@ -854,5 +1051,13 @@ deployment_config {
      }
    }
  }
+  executor_specs {
+    key: "as_optional_component"
+    value {
+      [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] {
+        class_path: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.Executor"
+      }
+    }
+  }
 }
}
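
Reviewer note (not part of the patch): below is a minimal, self-contained
sketch of the hashing pitfall this change works around. `SketchChannel` is a
hypothetical stand-in for `tfx.types.Channel`; per the commit message, Channel
hashes by object identity, so the copy returned by `as_optional()` can never
hit a dict entry keyed by the original channel.

# Hypothetical stand-in for tfx.types.Channel; names are illustrative,
# not the real TFX API.
class SketchChannel:

  def __init__(self, producer_id: str, output_key: str, optional: bool = False):
    self.producer_id = producer_id
    self.output_key = output_key
    self.optional = optional

  def __hash__(self) -> int:
    # Hash by object identity rather than by field values.
    return id(self)

  def as_optional(self) -> 'SketchChannel':
    # Returns a *new* object, so it hashes differently from the original.
    return SketchChannel(self.producer_id, self.output_key, optional=True)


compiled_channels = {}  # plays the role of pipeline_ctx.channels
original = SketchChannel('UpstreamComponent', 'second_model')
compiled_channels[original] = 'compiled channel proto'

assert original in compiled_channels                     # cache hit
assert original.as_optional() not in compiled_channels   # cache miss

# Before this patch, that miss meant a copied OutputChannel fell through to
# the generic Channel branch, which never attached the pipeline_run context
# query. The patch adds an explicit OutputChannel branch that recompiles the
# channel and copies the pipeline_run context from the node's own contexts.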