Skip to content

Commit

Permalink
Add tag to ExternalPipelineChannel so we can get artifacts by tags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630465771
  • Loading branch information
tfx-copybara committed May 24, 2024
1 parent fb06979 commit e86e1ee
Show file tree
Hide file tree
Showing 7 changed files with 610 additions and 6 deletions.
2 changes: 2 additions & 0 deletions tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tfx.dsl.compiler.testdata import conditional_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline
from tfx.dsl.compiler.testdata import consumer_pipeline_different_project
from tfx.dsl.compiler.testdata import consumer_pipeline_with_tags
from tfx.dsl.compiler.testdata import dynamic_exec_properties_pipeline
from tfx.dsl.compiler.testdata import external_artifacts_pipeline
from tfx.dsl.compiler.testdata import foreach_pipeline
Expand Down Expand Up @@ -143,6 +144,7 @@ def _get_pipeline_ir(self, filename: str) -> pipeline_pb2.Pipeline:
consumer_pipeline,
external_artifacts_pipeline,
consumer_pipeline_different_project,
consumer_pipeline_with_tags,
])
)
def testCompile(
Expand Down
80 changes: 76 additions & 4 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
"""Compiler submodule specialized for NodeInputs."""

from collections.abc import Iterable
from typing import Type, cast
from collections.abc import Iterable, Sequence
from typing import Optional, Type, cast

from tfx import types
from tfx.dsl.compiler import compiler_context
Expand All @@ -41,6 +41,8 @@

from ml_metadata.proto import metadata_store_pb2

PropertyPredicate = pipeline_pb2.PropertyPredicate


def _get_tfx_value(value: str) -> pipeline_pb2.Value:
"""Returns a TFX Value containing the provided string."""
Expand Down Expand Up @@ -137,12 +139,16 @@ def compile_op_node(op_node: resolver_op.OpNode):
def _compile_channel_pb_contexts(
context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]],
result: pipeline_pb2.InputSpec.Channel,
property_predicate: Optional[PropertyPredicate] = None,
):
"""Adds contexts to the channel."""
for context_type, context_value in context_types_and_names:
ctx = result.context_queries.add()
ctx.type.name = context_type
ctx.name.CopyFrom(context_value)
if context_value:
ctx.name.CopyFrom(context_value)
if property_predicate:
ctx.property_predicate.CopyFrom(property_predicate)


def _compile_channel_pb(
Expand Down Expand Up @@ -175,6 +181,67 @@ def _compile_channel_pb(
result.output_key = output_key


def _compile_run_context_predicate(
run_context_predicates: Sequence[tuple[str, metadata_store_pb2.Value]],
result_input_channel: pipeline_pb2.InputSpec.Channel,
):
"""Compile run context property predicates into InputSpec.Channel."""
if not run_context_predicates:
return

predicates = []
for run_context_predicate in run_context_predicates:
predicates.append(
PropertyPredicate(
value_comparator=PropertyPredicate.ValueComparator(
property_name=run_context_predicate[0],
op=PropertyPredicate.ValueComparator.Op.EQ,
target_value=pipeline_pb2.Value(
field_value=run_context_predicate[1]
),
is_custom_property=True,
)
)
)

if not predicates:
return
elif len(predicates) == 1:
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(''),
)],
result_input_channel,
predicates[0],
)
else:
binary_operator_predicate = PropertyPredicate(
binary_logical_operator=PropertyPredicate.BinaryLogicalOperator(
op=PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND,
lhs=predicates[0],
rhs=predicates[1],
)
)
for i in range(2, len(predicates)):
binary_operator_predicate = PropertyPredicate(
binary_logical_operator=PropertyPredicate.BinaryLogicalOperator(
op=PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND,
lhs=binary_operator_predicate,
rhs=predicates[i],
)
)

_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(''),
)],
result_input_channel,
binary_operator_predicate,
)


def _compile_input_spec(
*,
pipeline_ctx: compiler_context.PipelineContext,
Expand Down Expand Up @@ -206,7 +273,7 @@ def _compile_input_spec(
# from the same resolver function output.
if not hidden:
# Overwrite hidden = False even for already compiled channel, this is
# because we don't know the input should truely be hidden until the
# because we don't know the input should truly be hidden until the
# channel turns out not to be.
result.inputs[input_key].hidden = False
return
Expand Down Expand Up @@ -249,6 +316,11 @@ def _compile_input_spec(
result_input_channel,
)

if channel.run_context_predicates:
_compile_run_context_predicate(
channel.run_context_predicates, result_input_channel
)

if pipeline_ctx.pipeline.platform_config:
project_config = (
pipeline_ctx.pipeline.platform_config.project_platform_config
Expand Down
Loading

0 comments on commit e86e1ee

Please sign in to comment.