From 9ccee8273746938778675eff0fc29a9e93b65f04 Mon Sep 17 00:00:00 2001
From: kmonte
Date: Mon, 15 Apr 2024 15:32:00 -0700
Subject: [PATCH] Add subpipeline_utils module with utilities for
 orchestrating subpipelines

PiperOrigin-RevId: 625103692
---
 tfx/dsl/compiler/compiler.py                  |  87 ++++++--
 tfx/dsl/compiler/compiler_context.py          |   2 -
 tfx/dsl/compiler/node_contexts_compiler.py    | 108 ---------
 .../compiler/node_contexts_compiler_test.py   | 157 --------------
 tfx/dsl/compiler/node_inputs_compiler.py      |  90 ++------
 tfx/dsl/compiler/node_inputs_compiler_test.py |   3 +-
 .../optional_and_allow_empty_pipeline.py      |  16 +-
 ...and_allow_empty_pipeline_input_v2_ir.pbtxt | 205 ------------------
 tfx/orchestration/subpipeline_utils.py        |  54 +++++
 tfx/orchestration/subpipeline_utils_test.py   |  47 ++++
 10 files changed, 185 insertions(+), 584 deletions(-)
 delete mode 100644 tfx/dsl/compiler/node_contexts_compiler.py
 delete mode 100644 tfx/dsl/compiler/node_contexts_compiler_test.py
 create mode 100644 tfx/orchestration/subpipeline_utils.py
 create mode 100644 tfx/orchestration/subpipeline_utils_test.py

diff --git a/tfx/dsl/compiler/compiler.py b/tfx/dsl/compiler/compiler.py
index 7e4bf0c97a5..4af95be5af7 100644
--- a/tfx/dsl/compiler/compiler.py
+++ b/tfx/dsl/compiler/compiler.py
@@ -19,7 +19,6 @@
 from tfx.dsl.compiler import compiler_context
 from tfx.dsl.compiler import compiler_utils
 from tfx.dsl.compiler import constants
-from tfx.dsl.compiler import node_contexts_compiler
 from tfx.dsl.compiler import node_execution_options_utils
 from tfx.dsl.compiler import node_inputs_compiler
 from tfx.dsl.components.base import base_component
@@ -57,12 +56,7 @@ def _compile_pipeline_begin_node(
 
   # Step 2: Node Context
   # Inner pipeline's contexts.
-  node.contexts.CopyFrom(
-      node_contexts_compiler.compile_node_contexts(
-          pipeline_ctx,
-          node.node_info.id,
-      )
-  )
+  _set_node_context(node, pipeline_ctx)
 
   # Step 3: Node inputs
   # Pipeline node inputs are stored as the inputs of the PipelineBegin node.
@@ -127,12 +121,7 @@ def _compile_pipeline_end_node(
 
   # Step 2: Node Context
   # Inner pipeline's contexts.
-  node.contexts.CopyFrom(
-      node_contexts_compiler.compile_node_contexts(
-          pipeline_ctx,
-          node.node_info.id,
-      )
-  )
+  _set_node_context(node, pipeline_ctx)
 
   # Step 3: Node inputs
   node_inputs_compiler.compile_node_inputs(
@@ -205,12 +194,7 @@ def _compile_node(
   node.node_info.id = tfx_node.id
 
   # Step 2: Node Context
-  node.contexts.CopyFrom(
-      node_contexts_compiler.compile_node_contexts(
-          pipeline_ctx,
-          node.node_info.id,
-      )
-  )
+  _set_node_context(node, pipeline_ctx)
 
   # Step 3: Node inputs
   node_inputs_compiler.compile_node_inputs(
@@ -402,6 +386,71 @@ def _validate_pipeline(tfx_pipeline: pipeline.Pipeline,
         raise ValueError("Subpipeline has to be Sync execution mode.")
 
 
+def _set_node_context(node: pipeline_pb2.PipelineNode,
+                      pipeline_ctx: compiler_context.PipelineContext):
+  """Compiles the node contexts of a pipeline node."""
+  # Context for the pipeline, across pipeline runs.
+  pipeline_context_pb = node.contexts.contexts.add()
+  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
+  pipeline_context_pb.name.field_value.string_value = (
+      pipeline_ctx.pipeline_info.pipeline_context_name)
+
+  # Context for the current pipeline run.
+  if pipeline_ctx.is_sync_mode:
+    pipeline_run_context_pb = node.contexts.contexts.add()
+    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
+    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
+    # To keep existing IR textprotos used in tests unchanged, we only use
+    # structural_runtime_parameter for subpipelines. Once subpipelines are
+    # implemented, we will need to migrate normal pipelines to
+    # structural_runtime_parameter as well for consistency. Same below.
+    if pipeline_ctx.is_subpipeline:
+      compiler_utils.set_structural_runtime_parameter_pb(
+          pipeline_run_context_pb.name.structural_runtime_parameter, [
+              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
+              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
+          ])
+    else:
+      compiler_utils.set_runtime_parameter_pb(
+          pipeline_run_context_pb.name.runtime_parameter,
+          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
+
+  # Contexts inherited from the parent pipelines.
+  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
+    parent_pipeline_context_pb = node.contexts.contexts.add()
+    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
+    parent_pipeline_context_pb.name.field_value.string_value = (
+        parent_pipeline.pipeline_info.pipeline_context_name)
+
+    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
+      pipeline_run_context_pb = node.contexts.contexts.add()
+      pipeline_run_context_pb.type.name = (
+          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)
+
+      # TODO(kennethyang): Migrate pipeline run id to structural runtime
+      # parameter for the same reason mentioned above.
+      # Use structural runtime parameter to represent pipeline_run_id except
+      # for the root level pipeline, for backward compatibility.
+      if i == len(pipeline_ctx.parent_pipelines) - 1:
+        compiler_utils.set_runtime_parameter_pb(
+            pipeline_run_context_pb.name.runtime_parameter,
+            constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
+      else:
+        compiler_utils.set_structural_runtime_parameter_pb(
+            pipeline_run_context_pb.name.structural_runtime_parameter, [
+                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
+                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
+            ])
+
+  # Context for the node, across pipeline runs.
+  node_context_pb = node.contexts.contexts.add()
+  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
+  node_context_pb.name.field_value.string_value = (
+      compiler_utils.node_context_name(
+          pipeline_ctx.pipeline_info.pipeline_context_name,
+          node.node_info.id))
+
+
 def _set_node_outputs(node: pipeline_pb2.PipelineNode,
                       tfx_node_outputs: Dict[str, types.Channel]):
   """Compiles the outputs of a pipeline node."""
diff --git a/tfx/dsl/compiler/compiler_context.py b/tfx/dsl/compiler/compiler_context.py
index 8549d79c2ea..17193cb4f25 100644
--- a/tfx/dsl/compiler/compiler_context.py
+++ b/tfx/dsl/compiler/compiler_context.py
@@ -55,8 +55,6 @@ def __init__(self,
     # Mapping from Channel object to compiled Channel proto.
     self.channels = dict()
 
-    self.node_context_protos_cache: dict[str, pipeline_pb2.NodeContexts] = {}
-
     # Node ID -> NodeContext
     self._node_contexts: Dict[str, NodeContext] = {}
 
diff --git a/tfx/dsl/compiler/node_contexts_compiler.py b/tfx/dsl/compiler/node_contexts_compiler.py
deleted file mode 100644
index 73e73ea032f..00000000000
--- a/tfx/dsl/compiler/node_contexts_compiler.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Copyright 2024 Google LLC. All Rights Reserved.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Compiles NodeContexts.""" - -from tfx.dsl.compiler import compiler_context -from tfx.dsl.compiler import compiler_utils -from tfx.dsl.compiler import constants -from tfx.orchestration import pipeline -from tfx.proto.orchestration import pipeline_pb2 - - -def compile_node_contexts( - pipeline_ctx: compiler_context.PipelineContext, - node_id: str, -) -> pipeline_pb2.NodeContexts: - """Compiles the node contexts of a pipeline node.""" - - if pipeline_ctx.pipeline_info is None: - return pipeline_pb2.NodeContexts() - if maybe_contexts := pipeline_ctx.node_context_protos_cache.get(node_id): - return maybe_contexts - - node_contexts = pipeline_pb2.NodeContexts() - # Context for the pipeline, across pipeline runs. - pipeline_context_pb = node_contexts.contexts.add() - pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME - pipeline_context_pb.name.field_value.string_value = ( - pipeline_ctx.pipeline_info.pipeline_context_name - ) - - # Context for the current pipeline run. - if pipeline_ctx.is_sync_mode: - pipeline_run_context_pb = node_contexts.contexts.add() - pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME - # TODO(kennethyang): Miragte pipeline run id to structural_runtime_parameter - # To keep existing IR textprotos used in tests unchanged, we only use - # structural_runtime_parameter for subpipelines. After the subpipeline being - # implemented, we will need to migrate normal pipelines to - # structural_runtime_parameter as well for consistency. Similar for below. - if pipeline_ctx.is_subpipeline: - compiler_utils.set_structural_runtime_parameter_pb( - pipeline_run_context_pb.name.structural_runtime_parameter, - [ - f"{pipeline_ctx.pipeline_info.pipeline_context_name}_", - (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str), - ], - ) - else: - compiler_utils.set_runtime_parameter_pb( - pipeline_run_context_pb.name.runtime_parameter, - constants.PIPELINE_RUN_ID_PARAMETER_NAME, - str, - ) - - # Contexts inherited from the parent pipelines. - for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]): - parent_pipeline_context_pb = node_contexts.contexts.add() - parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME - parent_pipeline_context_pb.name.field_value.string_value = ( - parent_pipeline.pipeline_info.pipeline_context_name - ) - - if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC: - pipeline_run_context_pb = node_contexts.contexts.add() - pipeline_run_context_pb.type.name = ( - constants.PIPELINE_RUN_CONTEXT_TYPE_NAME - ) - - # TODO(kennethyang): Miragte pipeline run id to structural runtime - # parameter for the similar reason mentioned above. - # Use structural runtime parameter to represent pipeline_run_id except - # for the root level pipeline, for backward compatibility. 
- if i == len(pipeline_ctx.parent_pipelines) - 1: - compiler_utils.set_runtime_parameter_pb( - pipeline_run_context_pb.name.runtime_parameter, - constants.PIPELINE_RUN_ID_PARAMETER_NAME, - str, - ) - else: - compiler_utils.set_structural_runtime_parameter_pb( - pipeline_run_context_pb.name.structural_runtime_parameter, - [ - f"{parent_pipeline.pipeline_info.pipeline_context_name}_", - (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str), - ], - ) - - # Context for the node, across pipeline runs. - node_context_pb = node_contexts.contexts.add() - node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME - node_context_pb.name.field_value.string_value = ( - compiler_utils.node_context_name( - pipeline_ctx.pipeline_info.pipeline_context_name, node_id - ) - ) - pipeline_ctx.node_context_protos_cache[node_id] = node_contexts - return node_contexts diff --git a/tfx/dsl/compiler/node_contexts_compiler_test.py b/tfx/dsl/compiler/node_contexts_compiler_test.py deleted file mode 100644 index c30d9d50df9..00000000000 --- a/tfx/dsl/compiler/node_contexts_compiler_test.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2024 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for tfx.dsl.compiler.node_contexts_compiler.""" - -import tensorflow as tf -from tfx.dsl.compiler import compiler_context -from tfx.dsl.compiler import node_contexts_compiler -from tfx.orchestration import pipeline -from tfx.proto.orchestration import pipeline_pb2 - -from google.protobuf import text_format - -_NODE_ID = 'test_node' -_PIPELINE_NAME = 'test_pipeline' - - -class NodeContextsCompilerTest(tf.test.TestCase): - - def test_compile_node_contexts(self): - expected_node_contexts = text_format.Parse( - """ - contexts { - type { - name: "pipeline" - } - name { - field_value { - string_value: "test_pipeline" - } - } - } - contexts { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - contexts { - type { - name: "node" - } - name { - field_value { - string_value: "test_pipeline.test_node" - } - } - } - """, - pipeline_pb2.NodeContexts(), - ) - self.assertProtoEquals( - node_contexts_compiler.compile_node_contexts( - compiler_context.PipelineContext(pipeline.Pipeline(_PIPELINE_NAME)), - _NODE_ID, - ), - expected_node_contexts, - ) - - def test_compile_node_contexts_for_subpipeline(self): - parent_context = compiler_context.PipelineContext( - pipeline.Pipeline(_PIPELINE_NAME) - ) - subpipeline_context = compiler_context.PipelineContext( - pipeline.Pipeline('subpipeline'), parent_context - ) - - expected_node_contexts = text_format.Parse( - """ - contexts { - type { - name: "pipeline" - } - name { - field_value { - string_value: "subpipeline" - } - } - } - contexts { - type { - name: "pipeline_run" - } - name { - structural_runtime_parameter { - parts { - constant_value: "subpipeline_" - } - parts { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - } - } - contexts { - type { - name: "pipeline" - } - name { 
- field_value { - string_value: "test_pipeline" - } - } - } - contexts { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - contexts { - type { - name: "node" - } - name { - field_value { - string_value: "subpipeline.test_node" - } - } - } - """, - pipeline_pb2.NodeContexts(), - ) - self.assertProtoEquals( - node_contexts_compiler.compile_node_contexts( - subpipeline_context, - _NODE_ID, - ), - expected_node_contexts, - ) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tfx/dsl/compiler/node_inputs_compiler.py b/tfx/dsl/compiler/node_inputs_compiler.py index bd6423ecae4..6a1d2bf4cef 100644 --- a/tfx/dsl/compiler/node_inputs_compiler.py +++ b/tfx/dsl/compiler/node_inputs_compiler.py @@ -13,14 +13,12 @@ # limitations under the License. """Compiler submodule specialized for NodeInputs.""" -from collections.abc import Iterable from typing import Type, cast from tfx import types from tfx.dsl.compiler import compiler_context from tfx.dsl.compiler import compiler_utils from tfx.dsl.compiler import constants -from tfx.dsl.compiler import node_contexts_compiler from tfx.dsl.components.base import base_component from tfx.dsl.components.base import base_node from tfx.dsl.experimental.conditionals import conditional @@ -39,17 +37,6 @@ from tfx.utils import name_utils from tfx.utils import typing_utils -from ml_metadata.proto import metadata_store_pb2 - - -def _get_tfx_value(value: str) -> pipeline_pb2.Value: - """Returns a TFX Value containing the provided string.""" - return pipeline_pb2.Value( - field_value=data_types_utils.set_metadata_value( - metadata_store_pb2.Value(), value - ) - ) - def _compile_input_graph( pipeline_ctx: compiler_context.PipelineContext, @@ -134,17 +121,6 @@ def compile_op_node(op_node: resolver_op.OpNode): return input_graph_id -def _compile_channel_pb_contexts( - context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]], - result: pipeline_pb2.InputSpec.Channel, -): - """Adds contexts to the channel.""" - for context_type, context_value in context_types_and_names: - ctx = result.context_queries.add() - ctx.type.name = context_type - ctx.name.CopyFrom(context_value) - - def _compile_channel_pb( artifact_type: Type[types.Artifact], pipeline_name: str, @@ -157,19 +133,15 @@ def _compile_channel_pb( result.artifact_query.type.CopyFrom(mlmd_artifact_type) result.artifact_query.type.ClearField('properties') - contexts_types_and_values = [ - (constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value(pipeline_name)) - ] + ctx = result.context_queries.add() + ctx.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME + ctx.name.field_value.string_value = pipeline_name + if node_id: - contexts_types_and_values.append( - ( - constants.NODE_CONTEXT_TYPE_NAME, - _get_tfx_value( - compiler_utils.node_context_name(pipeline_name, node_id) - ), - ), - ) - _compile_channel_pb_contexts(contexts_types_and_values, result) + ctx = result.context_queries.add() + ctx.type.name = constants.NODE_CONTEXT_TYPE_NAME + ctx.name.field_value.string_value = compiler_utils.node_context_name( + pipeline_name, node_id) if output_key: result.output_key = output_key @@ -226,8 +198,7 @@ def _compile_input_spec( pipeline_name=channel.pipeline.id, node_id=channel.wrapped.producer_component_id, output_key=channel.output_key, - result=result.inputs[input_key].channels.add(), - ) + result=result.inputs[input_key].channels.add()) elif isinstance(channel, channel_types.ExternalPipelineChannel): channel = 
cast(channel_types.ExternalPipelineChannel, channel) @@ -237,17 +208,12 @@ def _compile_input_spec( pipeline_name=channel.pipeline_name, node_id=channel.producer_component_id, output_key=channel.output_key, - result=result_input_channel, - ) + result=result_input_channel) if channel.pipeline_run_id: - _compile_channel_pb_contexts( - [( - constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, - _get_tfx_value(channel.pipeline_run_id), - )], - result_input_channel, - ) + ctx = result_input_channel.context_queries.add() + ctx.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME + ctx.name.field_value.string_value = channel.pipeline_run_id if pipeline_ctx.pipeline.platform_config: project_config = ( @@ -269,33 +235,6 @@ def _compile_input_spec( ) result_input_channel.metadata_connection_config.Pack(config) - # Note that this path is *usually* not taken, as most output channels already - # exist in pipeline_ctx.channels, as they are added in after - # compiler._generate_input_spec_for_outputs is called. - # This path gets taken when a channel is copied, for example by - # `as_optional()`, as Channel uses `id` for a hash. - elif isinstance(channel, channel_types.OutputChannel): - channel = cast(channel_types.Channel, channel) - result_input_channel = result.inputs[input_key].channels.add() - _compile_channel_pb( - artifact_type=channel.type, - pipeline_name=pipeline_ctx.pipeline_info.pipeline_name, - node_id=channel.producer_component_id, - output_key=channel.output_key, - result=result_input_channel, - ) - node_contexts = node_contexts_compiler.compile_node_contexts( - pipeline_ctx, tfx_node.id - ) - contexts_to_add = [] - for context_spec in node_contexts.contexts: - if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME: - contexts_to_add.append(( - constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, - context_spec.name, - )) - _compile_channel_pb_contexts(contexts_to_add, result_input_channel) - elif isinstance(channel, channel_types.Channel): channel = cast(channel_types.Channel, channel) _compile_channel_pb( @@ -303,8 +242,7 @@ def _compile_input_spec( pipeline_name=pipeline_ctx.pipeline_info.pipeline_name, node_id=channel.producer_component_id, output_key=channel.output_key, - result=result.inputs[input_key].channels.add(), - ) + result=result.inputs[input_key].channels.add()) elif isinstance(channel, channel_types.UnionChannel): channel = cast(channel_types.UnionChannel, channel) diff --git a/tfx/dsl/compiler/node_inputs_compiler_test.py b/tfx/dsl/compiler/node_inputs_compiler_test.py index 5bb2844e4f0..d2b3301cd30 100644 --- a/tfx/dsl/compiler/node_inputs_compiler_test.py +++ b/tfx/dsl/compiler/node_inputs_compiler_test.py @@ -145,8 +145,7 @@ def _get_channel_pb( pipeline_name=pipeline_name or self.pipeline_name, node_id=node_id, output_key=output_key, - result=result, - ) + result=result) return result def testCompileAlreadyCompiledInputs(self): diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py index 43ef1ce8145..e9b51b46a43 100644 --- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py +++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline.py @@ -112,15 +112,6 @@ def create_test_pipeline(): mandatory=upstream_component.outputs['first_model'], optional_but_needed=upstream_component.outputs['second_model'], optional_and_not_needed=upstream_component.outputs['third_model']) - as_optional_component = MyComponent( - mandatory=upstream_component.outputs['second_model'].as_optional(), - 
optional_but_needed=upstream_component.outputs[ - 'second_model' - ].as_optional(), - optional_and_not_needed=upstream_component.outputs[ - 'third_model' - ].as_optional(), - ).with_id('as_optional_component') p_in = pipeline.PipelineInputs({ 'mandatory': upstream_component.outputs['first_model'], 'optional': upstream_component.outputs['second_model'].as_optional(), @@ -138,10 +129,5 @@ def create_test_pipeline(): return pipeline.Pipeline( pipeline_name=_pipeline_name, pipeline_root=_pipeline_root, - components=[ - upstream_component, - my_component, - as_optional_component, - subpipeline, - ], + components=[upstream_component, my_component, subpipeline], ) diff --git a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt index 2cff1ca2f3c..0355afd2f5d 100644 --- a/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt +++ b/tfx/dsl/compiler/testdata/optional_and_allow_empty_pipeline_input_v2_ir.pbtxt @@ -84,7 +84,6 @@ nodes { } } downstream_nodes: "MyComponent" - downstream_nodes: "as_optional_component" downstream_nodes: "subpipeline" execution_options { caching_options { @@ -287,191 +286,6 @@ nodes { } } } -nodes { - pipeline_node { - node_info { - type { - name: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.MyComponent" - } - id: "as_optional_component" - } - contexts { - contexts { - type { - name: "pipeline" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline" - } - } - } - contexts { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - contexts { - type { - name: "node" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline.as_optional_component" - } - } - } - } - inputs { - inputs { - key: "mandatory" - value { - channels { - context_queries { - type { - name: "pipeline" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline" - } - } - } - context_queries { - type { - name: "node" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline.UpstreamComponent" - } - } - } - context_queries { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - artifact_query { - type { - name: "Model" - base_type: MODEL - } - } - output_key: "second_model" - } - } - } - inputs { - key: "optional_and_not_needed" - value { - channels { - context_queries { - type { - name: "pipeline" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline" - } - } - } - context_queries { - type { - name: "node" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline.UpstreamComponent" - } - } - } - context_queries { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } - artifact_query { - type { - name: "Model" - base_type: MODEL - } - } - output_key: "third_model" - } - } - } - inputs { - key: "optional_but_needed" - value { - channels { - context_queries { - type { - name: "pipeline" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline" - } - } - } - context_queries { - type { - name: "node" - } - name { - field_value { - string_value: "optional_and_allow_empty_pipeline.UpstreamComponent" - } - } - } - context_queries { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: 
"pipeline-run-id" - type: STRING - } - } - } - artifact_query { - type { - name: "Model" - base_type: MODEL - } - } - output_key: "second_model" - } - } - } - } - upstream_nodes: "UpstreamComponent" - execution_options { - caching_options { - } - } - } -} nodes { sub_pipeline { pipeline_info { @@ -620,17 +434,6 @@ nodes { } } } - context_queries { - type { - name: "pipeline_run" - } - name { - runtime_parameter { - name: "pipeline-run-id" - type: STRING - } - } - } artifact_query { type { name: "Model" @@ -1051,13 +854,5 @@ deployment_config { } } } - executor_specs { - key: "as_optional_component" - value { - [type.googleapis.com/tfx.orchestration.executable_spec.PythonClassExecutableSpec] { - class_path: "tfx.dsl.compiler.testdata.optional_and_allow_empty_pipeline.Executor" - } - } - } } } diff --git a/tfx/orchestration/subpipeline_utils.py b/tfx/orchestration/subpipeline_utils.py new file mode 100644 index 00000000000..a5598c26f04 --- /dev/null +++ b/tfx/orchestration/subpipeline_utils.py @@ -0,0 +1,54 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generic utilities for orchestrating subpipelines.""" + + +from tfx.dsl.compiler import compiler_utils +from tfx.dsl.compiler import constants as compiler_constants +from tfx.orchestration import pipeline as dsl_pipeline +from tfx.proto.orchestration import pipeline_pb2 + +# This pipeline *only* exists so that we can correctly infer the correct node +# types for pipeline begin and end nodes, as the compiler uses a Python Pipeline +# object to generate the names. +# This pipeline *should not* be used otherwise. +_DUMMY_PIPELINE = dsl_pipeline.Pipeline(pipeline_name="UNUSED") + + +def is_subpipeline(pipeline: pipeline_pb2.Pipeline) -> bool: + """Returns True if the pipeline is a subpipeline.""" + nodes = pipeline.nodes + if len(nodes) < 2: + return False + maybe_begin_node = nodes[0] + maybe_end_node = nodes[-1] + if ( + maybe_begin_node.WhichOneof("node") != "pipeline_node" + or maybe_begin_node.pipeline_node.node_info.id + != f"{pipeline.pipeline_info.id}{compiler_constants.PIPELINE_BEGIN_NODE_SUFFIX}" + or maybe_begin_node.pipeline_node.node_info.type.name + != compiler_utils.pipeline_begin_node_type_name(_DUMMY_PIPELINE) + ): + return False + if ( + maybe_end_node.WhichOneof("node") != "pipeline_node" + or maybe_end_node.pipeline_node.node_info.id + != compiler_utils.pipeline_end_node_id_from_pipeline_id( + pipeline.pipeline_info.id + ) + or maybe_end_node.pipeline_node.node_info.type.name + != compiler_utils.pipeline_end_node_type_name(_DUMMY_PIPELINE) + ): + return False + return True diff --git a/tfx/orchestration/subpipeline_utils_test.py b/tfx/orchestration/subpipeline_utils_test.py new file mode 100644 index 00000000000..ba7f1d57c83 --- /dev/null +++ b/tfx/orchestration/subpipeline_utils_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 Google LLC. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for tfx.orchestration.subpipeline_utils.""" + +from absl.testing import absltest +from absl.testing import parameterized +from tfx.dsl.compiler import compiler +from tfx.orchestration import pipeline as dsl_pipeline +from tfx.orchestration import subpipeline_utils + +_PIPELINE_NAME = 'test_pipeline' +_TEST_PIPELINE = dsl_pipeline.Pipeline(pipeline_name=_PIPELINE_NAME) + + +class SubpipelineUtilsTest(parameterized.TestCase): + + def test_is_subpipeline_with_subpipeline(self): + subpipeline = dsl_pipeline.Pipeline(pipeline_name='subpipeline') + pipeline = dsl_pipeline.Pipeline( + pipeline_name=_PIPELINE_NAME, components=[subpipeline] + ) + pipeline_ir = compiler.Compiler().compile(pipeline) + subpipeline_ir = pipeline_ir.nodes[0].sub_pipeline + self.assertTrue(subpipeline_utils.is_subpipeline(subpipeline_ir)) + + def test_is_subpipeline_with_parent_pipelines(self): + subpipeline = dsl_pipeline.Pipeline(pipeline_name='subpipeline') + pipeline = dsl_pipeline.Pipeline( + pipeline_name=_PIPELINE_NAME, components=[subpipeline] + ) + pipeline_ir = compiler.Compiler().compile(pipeline) + self.assertFalse(subpipeline_utils.is_subpipeline(pipeline_ir)) + + +if __name__ == '__main__': + absltest.main()
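
A minimal usage sketch of the new helper, mirroring the tests above; the
pipeline names are illustrative and this snippet is not part of the patch:

from tfx.dsl.compiler import compiler
from tfx.orchestration import pipeline as dsl_pipeline
from tfx.orchestration import subpipeline_utils

# Compile an outer pipeline that embeds another pipeline as a node.
inner = dsl_pipeline.Pipeline(pipeline_name='inner')
outer = dsl_pipeline.Pipeline(pipeline_name='outer', components=[inner])
outer_ir = compiler.Compiler().compile(outer)

# The root pipeline IR is not a subpipeline: its nodes are not bracketed by
# compiler-generated begin/end pipeline nodes.
assert not subpipeline_utils.is_subpipeline(outer_ir)

# The IR embedded in the outer pipeline's node is: the compiler brackets it
# with generated begin and end nodes, which is exactly what is_subpipeline
# checks for.
inner_ir = outer_ir.nodes[0].sub_pipeline
assert subpipeline_utils.is_subpipeline(inner_ir)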