From e86e1eec57a106e37903c43bd697de0a50ca9095 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Fri, 3 May 2024 12:29:35 -0700 Subject: [PATCH] Add tag to ExternalPipelineChannel so we can get artifacts by tags. PiperOrigin-RevId: 630465771 --- tfx/dsl/compiler/compiler_test.py | 2 + tfx/dsl/compiler/node_inputs_compiler.py | 80 +++++- tfx/dsl/compiler/node_inputs_compiler_test.py | 251 ++++++++++++++++++ .../testdata/consumer_pipeline_with_tags.py | 37 +++ ...sumer_pipeline_with_tags_input_v2_ir.pbtxt | 210 +++++++++++++++ tfx/types/channel.py | 9 +- tfx/types/channel_utils.py | 27 +- 7 files changed, 610 insertions(+), 6 deletions(-) create mode 100644 tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py create mode 100644 tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt diff --git a/tfx/dsl/compiler/compiler_test.py b/tfx/dsl/compiler/compiler_test.py index 013e895b0fe..b9e5cdf6bbf 100644 --- a/tfx/dsl/compiler/compiler_test.py +++ b/tfx/dsl/compiler/compiler_test.py @@ -33,6 +33,7 @@ from tfx.dsl.compiler.testdata import conditional_pipeline from tfx.dsl.compiler.testdata import consumer_pipeline from tfx.dsl.compiler.testdata import consumer_pipeline_different_project +from tfx.dsl.compiler.testdata import consumer_pipeline_with_tags from tfx.dsl.compiler.testdata import dynamic_exec_properties_pipeline from tfx.dsl.compiler.testdata import external_artifacts_pipeline from tfx.dsl.compiler.testdata import foreach_pipeline @@ -143,6 +144,7 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline: consumer_pipeline, external_artifacts_pipeline, consumer_pipeline_different_project, + consumer_pipeline_with_tags, ]) ) def testCompile( diff --git a/tfx/dsl/compiler/node_inputs_compiler.py b/tfx/dsl/compiler/node_inputs_compiler.py index bd6423ecae4..8a2d03b6a72 100644 --- a/tfx/dsl/compiler/node_inputs_compiler.py +++ b/tfx/dsl/compiler/node_inputs_compiler.py @@ -13,8 +13,8 @@ # limitations under the License. """Compiler submodule specialized for NodeInputs.""" -from collections.abc import Iterable -from typing import Type, cast +from collections.abc import Iterable, Sequence +from typing import Optional, Type, cast from tfx import types from tfx.dsl.compiler import compiler_context @@ -41,6 +41,8 @@ from ml_metadata.proto import metadata_store_pb2 +PropertyPredicate = pipeline_pb2.PropertyPredicate + def _get_tfx_value(value: str) -> pipeline_pb2.Value: """Returns a TFX Value containing the provided string.""" @@ -137,12 +139,16 @@ def compile_op_node(op_node: resolver_op.OpNode): def _compile_channel_pb_contexts( context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]], result: pipeline_pb2.InputSpec.Channel, + property_predicate: Optional[PropertyPredicate] = None, ): """Adds contexts to the channel.""" for context_type, context_value in context_types_and_names: ctx = result.context_queries.add() ctx.type.name = context_type - ctx.name.CopyFrom(context_value) + if context_value: + ctx.name.CopyFrom(context_value) + if property_predicate: + ctx.property_predicate.CopyFrom(property_predicate) def _compile_channel_pb( @@ -175,6 +181,67 @@ def _compile_channel_pb( result.output_key = output_key +def _compile_run_context_predicate( + run_context_predicates: Sequence[tuple[str, metadata_store_pb2.Value]], + result_input_channel: pipeline_pb2.InputSpec.Channel, +): + """Compile run context property predicates into InputSpec.Channel.""" + if not run_context_predicates: + return + + predicates = [] + for run_context_predicate in run_context_predicates: + predicates.append( + PropertyPredicate( + value_comparator=PropertyPredicate.ValueComparator( + property_name=run_context_predicate[0], + op=PropertyPredicate.ValueComparator.Op.EQ, + target_value=pipeline_pb2.Value( + field_value=run_context_predicate[1] + ), + is_custom_property=True, + ) + ) + ) + + if not predicates: + return + elif len(predicates) == 1: + _compile_channel_pb_contexts( + [( + constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, + _get_tfx_value(''), + )], + result_input_channel, + predicates[0], + ) + else: + binary_operator_predicate = PropertyPredicate( + binary_logical_operator=PropertyPredicate.BinaryLogicalOperator( + op=PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND, + lhs=predicates[0], + rhs=predicates[1], + ) + ) + for i in range(2, len(predicates)): + binary_operator_predicate = PropertyPredicate( + binary_logical_operator=PropertyPredicate.BinaryLogicalOperator( + op=PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND, + lhs=binary_operator_predicate, + rhs=predicates[i], + ) + ) + + _compile_channel_pb_contexts( + [( + constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, + _get_tfx_value(''), + )], + result_input_channel, + binary_operator_predicate, + ) + + def _compile_input_spec( *, pipeline_ctx: compiler_context.PipelineContext, @@ -206,7 +273,7 @@ def _compile_input_spec( # from the same resolver function output. if not hidden: # Overwrite hidden = False even for already compiled channel, this is - # because we don't know the input should truely be hidden until the + # because we don't know the input should truly be hidden until the # channel turns out not to be. result.inputs[input_key].hidden = False return @@ -249,6 +316,11 @@ def _compile_input_spec( result_input_channel, ) + if channel.run_context_predicates: + _compile_run_context_predicate( + channel.run_context_predicates, result_input_channel + ) + if pipeline_ctx.pipeline.platform_config: project_config = ( pipeline_ctx.pipeline.platform_config.project_platform_config diff --git a/tfx/dsl/compiler/node_inputs_compiler_test.py b/tfx/dsl/compiler/node_inputs_compiler_test.py index 5bb2844e4f0..f554bb38269 100644 --- a/tfx/dsl/compiler/node_inputs_compiler_test.py +++ b/tfx/dsl/compiler/node_inputs_compiler_test.py @@ -37,6 +37,7 @@ from tfx.types import standard_artifacts from google.protobuf import text_format +from ml_metadata.proto import metadata_store_pb2 class DummyArtifact(types.Artifact): @@ -292,6 +293,256 @@ def testCompileInputGraph(self): ctx, node, channel, result) self.assertEqual(input_graph_id, second_input_graph_id) + def testCompilePropertyPredicateForTags(self): + with self.subTest('zero tag'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[], + ) + }, + ) + result = self._compile_node_inputs(consumer, components=[consumer]) + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + + with self.subTest('one tag'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[ + ('tag_1', metadata_store_pb2.Value(bool_value=True)) + ], + ) + }, + ) + + result = self._compile_node_inputs(consumer, components=[consumer]) + + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + value_comparator { + property_name: "tag_1" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + + with self.subTest('three tags'): + consumer = DummyNode( + 'MyConsumer', + inputs={ + 'input_key': channel_types.ExternalPipelineChannel( + artifact_type=DummyArtifact, + owner='MyProducer', + pipeline_name='pipeline_name', + producer_component_id='producer_component_id', + output_key='z', + run_context_predicates=[ + ('tag_1', metadata_store_pb2.Value(bool_value=True)), + ('tag_2', metadata_store_pb2.Value(bool_value=True)), + ('tag_3', metadata_store_pb2.Value(bool_value=True)), + ], + ) + }, + ) + + result = self._compile_node_inputs(consumer, components=[consumer]) + self.assertLen(result.inputs['input_key'].channels, 1) + self.assertProtoEquals( + """ + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "pipeline_name" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "pipeline_name.producer_component_id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + binary_logical_operator { + op: AND + lhs { + binary_logical_operator { + op: AND + lhs { + value_comparator { + property_name: "tag_1" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + rhs { + value_comparator { + property_name: "tag_2" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + rhs { + value_comparator { + property_name: "tag_3" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + } + artifact_query { + type { + name: "Dummy" + } + } + output_key: "z" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "MyProducer" + name: "pipeline_name" + } + } + """, + result.inputs['input_key'].channels[0], + ) + def testCompileInputGraphRef(self): with dummy_artifact_list.given_output_type(DummyArtifact): x1 = dummy_artifact_list() diff --git a/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py new file mode 100644 index 00000000000..de4b48ce51e --- /dev/null +++ b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py @@ -0,0 +1,37 @@ +# Copyright 2022 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test pipeline for tfx.dsl.compiler.compiler.""" + +from tfx.components import StatisticsGen +from tfx.orchestration import pipeline +from tfx.types import channel_utils +from tfx.types import standard_artifacts + + +def create_test_pipeline(): + """Builds a consumer pipeline that gets artifacts from another project.""" + external_examples = channel_utils.external_pipeline_artifact_query( + artifact_type=standard_artifacts.Examples, + owner='owner', + pipeline_name='producer-pipeline', + producer_component_id='producer-component-id', + output_key='output-key', + pipeline_run_tags=['tag1', 'tag2', 'tag3'], + ) + + statistics_gen = StatisticsGen(examples=external_examples) + + return pipeline.Pipeline( + pipeline_name='consumer-pipeline', components=[statistics_gen] + ) diff --git a/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt new file mode 100644 index 00000000000..008b581f4fc --- /dev/null +++ b/tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt @@ -0,0 +1,210 @@ +pipeline_info { + id: "consumer-pipeline" + } + nodes { + pipeline_node { + node_info { + type { + name: "tfx.components.statistics_gen.component.StatisticsGen" + base_type: PROCESS + } + id: "StatisticsGen" + } + contexts { + contexts { + type { + name: "pipeline" + } + name { + field_value { + string_value: "consumer-pipeline" + } + } + } + contexts { + type { + name: "pipeline_run" + } + name { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + contexts { + type { + name: "node" + } + name { + field_value { + string_value: "consumer-pipeline.StatisticsGen" + } + } + } + } + inputs { + inputs { + key: "examples" + value { + channels { + context_queries { + type { + name: "pipeline" + } + name { + field_value { + string_value: "producer-pipeline" + } + } + } + context_queries { + type { + name: "node" + } + name { + field_value { + string_value: "producer-pipeline.producer-component-id" + } + } + } + context_queries { + type { + name: "pipeline_run" + } + name { + field_value { + string_value: "" + } + } + property_predicate { + binary_logical_operator { + op: AND + lhs { + binary_logical_operator { + op: AND + lhs { + value_comparator { + property_name: "__tag_tag1__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + rhs { + value_comparator { + property_name: "__tag_tag2__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + rhs { + value_comparator { + property_name: "__tag_tag3__" + target_value { + field_value { + bool_value: true + } + } + op: EQ + is_custom_property: true + } + } + } + } + } + artifact_query { + type { + name: "Examples" + base_type: DATASET + } + } + output_key: "output-key" + metadata_connection_config { + [type.googleapis.com/tfx.orchestration.MLMDServiceConfig] { + owner: "owner" + name: "producer-pipeline" + } + } + } + min_count: 1 + } + } + } + outputs { + outputs { + key: "statistics" + value { + artifact_spec { + type { + name: "ExampleStatistics" + properties { + key: "span" + value: INT + } + properties { + key: "split_names" + value: STRING + } + base_type: STATISTICS + } + } + } + } + } + parameters { + parameters { + key: "exclude_splits" + value { + field_value { + string_value: "[]" + } + } + } + } + execution_options { + caching_options { + } + } + } + } + runtime_spec { + pipeline_root { + runtime_parameter { + name: "pipeline-root" + type: STRING + } + } + pipeline_run_id { + runtime_parameter { + name: "pipeline-run-id" + type: STRING + } + } + } + execution_mode: SYNC + deployment_config { + [type.googleapis.com/tfx.orchestration.IntermediateDeploymentConfig] { + executor_specs { + key: "StatisticsGen" + value { + [type.googleapis.com/tfx.orchestration.executable_spec.BeamExecutableSpec] { + python_executor_spec { + class_path: "tfx.components.statistics_gen.executor.Executor" + } + } + } + } + } + } \ No newline at end of file diff --git a/tfx/types/channel.py b/tfx/types/channel.py index 9c79ea7e4b7..a1e980836f4 100644 --- a/tfx/types/channel.py +++ b/tfx/types/channel.py @@ -722,6 +722,9 @@ def __init__( producer_component_id: str, output_key: str, pipeline_run_id: str = '', + run_context_predicates: Sequence[ + tuple[str, metadata_store_pb2.Value] + ] = (), ): """Initialization of ExternalPipelineChannel. @@ -733,6 +736,8 @@ def __init__( output_key: The output key when producer component produces the artifacts in this Channel. pipeline_run_id: (Optional) Pipeline run id the artifacts belong to. + run_context_predicates: (Optional) A list of run context property + predicates to filter run contexts. """ super().__init__(type=artifact_type) self.owner = owner @@ -740,6 +745,7 @@ def __init__( self.producer_component_id = producer_component_id self.output_key = output_key self.pipeline_run_id = pipeline_run_id + self.run_context_predicates = run_context_predicates def get_data_dependent_node_ids(self) -> Set[str]: return set() @@ -751,7 +757,8 @@ def __repr__(self) -> str: f'pipeline_name={self.pipeline_name}, ' f'producer_component_id={self.producer_component_id}, ' f'output_key={self.output_key}, ' - f'pipeline_run_id={self.pipeline_run_id})' + f'pipeline_run_id={self.pipeline_run_id}), ' + f'run_context_predicates={self.run_context_predicates}' ) diff --git a/tfx/types/channel_utils.py b/tfx/types/channel_utils.py index 37125538332..b9240cc1bd6 100644 --- a/tfx/types/channel_utils.py +++ b/tfx/types/channel_utils.py @@ -33,6 +33,8 @@ from tfx.types import artifact from tfx.types import channel +from ml_metadata.proto import metadata_store_pb2 + class ChannelForTesting(channel.BaseChannel): """Dummy channel for testing.""" @@ -149,6 +151,7 @@ def external_pipeline_artifact_query( producer_component_id: str, output_key: str, pipeline_run_id: str = '', + pipeline_run_tags: Sequence[str] = (), ) -> channel.ExternalPipelineChannel: """Helper function to construct a query to get artifacts from an external pipeline. @@ -160,16 +163,37 @@ def external_pipeline_artifact_query( output_key: The output key when producer component produces the artifacts in this Channel. pipeline_run_id: (Optional) Pipeline run id the artifacts belong to. + pipeline_run_tags: (Optional) A list of tags the artifacts belong to. It is + an AND relationship between tags. For example, if tags=['tag1', 'tag2'], + then only artifacts belonging to the run with both 'tag1' and 'tag2' will + be returned. Only one of pipeline_run_id and pipeline_run_tags can be set. Returns: channel.ExternalPipelineChannel instance. Raises: - ValueError, if owner or pipeline_name is missing. + ValueError, if owner or pipeline_name is missing, or both pipeline_run_id + and pipeline_run_tags are set. """ if not owner or not pipeline_name: raise ValueError('owner or pipeline_name is missing.') + if pipeline_run_id and pipeline_run_tags: + raise ValueError( + 'pipeline_run_id and pipeline_run_tags cannot be both set.' + ) + + run_context_predicates = [] + for tag in pipeline_run_tags: + # TODO(b/264728226): Find a better way to construct the tag name that used + # in MLMD. Tag names that used in MLMD are constructed in tflex_mlmd_api.py, + # but it is not visible in this file. + mlmd_store_tag = '__tag_' + tag + '__' + run_context_predicates.append(( + mlmd_store_tag, + metadata_store_pb2.Value(bool_value=True), + )) + return channel.ExternalPipelineChannel( artifact_type=artifact_type, owner=owner, @@ -177,6 +201,7 @@ def external_pipeline_artifact_query( producer_component_id=producer_component_id, output_key=output_key, pipeline_run_id=pipeline_run_id, + run_context_predicates=run_context_predicates, )