Skip to content

Commit

Permalink
Add subpipeline_utils to contain utils for orchestrating subpipelines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625103692
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 17, 2024
1 parent 0313e7c commit 9ccee82
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 584 deletions.
87 changes: 68 additions & 19 deletions tfx/dsl/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.dsl.compiler import node_contexts_compiler
from tfx.dsl.compiler import node_execution_options_utils
from tfx.dsl.compiler import node_inputs_compiler
from tfx.dsl.components.base import base_component
Expand Down Expand Up @@ -57,12 +56,7 @@ def _compile_pipeline_begin_node(

# Step 2: Node Context
# Inner pipeline's contexts.
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)
_set_node_context(node, pipeline_ctx)

# Step 3: Node inputs
# Pipeline node inputs are stored as the inputs of the PipelineBegin node.
Expand Down Expand Up @@ -127,12 +121,7 @@ def _compile_pipeline_end_node(

# Step 2: Node Context
# Inner pipeline's contexts.
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)
_set_node_context(node, pipeline_ctx)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -205,12 +194,7 @@ def _compile_node(
node.node_info.id = tfx_node.id

# Step 2: Node Context
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)
_set_node_context(node, pipeline_ctx)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -402,6 +386,71 @@ def _validate_pipeline(tfx_pipeline: pipeline.Pipeline,
raise ValueError("Subpipeline has to be Sync execution mode.")


def _set_node_context(node: pipeline_pb2.PipelineNode,
                      pipeline_ctx: compiler_context.PipelineContext):
  """Compiles the node contexts of a pipeline node.

  Appends, in order: the pipeline context, the pipeline-run context (sync
  mode only), one context per parent pipeline (plus a run context for sync
  parents), and finally the per-node context.

  Args:
    node: The PipelineNode proto whose `contexts` field is populated.
    pipeline_ctx: Compilation context of the pipeline that owns `node`.
  """
  # Context for the pipeline itself, stable across pipeline runs.
  owning_pipeline_context = node.contexts.contexts.add()
  owning_pipeline_context.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  owning_pipeline_context.name.field_value.string_value = (
      pipeline_ctx.pipeline_info.pipeline_context_name)

  # Context scoped to the current pipeline run; only sync pipelines have one.
  if pipeline_ctx.is_sync_mode:
    run_context = node.contexts.contexts.add()
    run_context.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
    # To keep existing IR textprotos used in tests unchanged, we only use
    # structural_runtime_parameter for subpipelines. After the subpipeline being
    # implemented, we will need to migrate normal pipelines to
    # structural_runtime_parameter as well for consistency. Similar for below.
    if not pipeline_ctx.is_subpipeline:
      compiler_utils.set_runtime_parameter_pb(
          run_context.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
    else:
      compiler_utils.set_structural_runtime_parameter_pb(
          run_context.name.structural_runtime_parameter, [
              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
          ])

  # Contexts inherited from the parent pipelines.
  ancestors = pipeline_ctx.parent_pipelines[::-1]
  # The last entry of the reversed list is the root-level pipeline; it keeps
  # a plain runtime parameter for backward compatibility (see TODO above).
  root_position = len(ancestors) - 1
  for position, ancestor in enumerate(ancestors):
    ancestor_context = node.contexts.contexts.add()
    ancestor_context.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    ancestor_context.name.field_value.string_value = (
        ancestor.pipeline_info.pipeline_context_name)

    if ancestor.execution_mode != pipeline.ExecutionMode.SYNC:
      continue

    ancestor_run_context = node.contexts.contexts.add()
    ancestor_run_context.type.name = (
        constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)

    # TODO(kennethyang): Migrate pipeline run id to structural runtime
    # parameter for the similar reason mentioned above.
    # Use structural runtime parameter to represent pipeline_run_id except
    # for the root level pipeline, for backward compatibility.
    if position == root_position:
      compiler_utils.set_runtime_parameter_pb(
          ancestor_run_context.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
    else:
      compiler_utils.set_structural_runtime_parameter_pb(
          ancestor_run_context.name.structural_runtime_parameter, [
              f"{ancestor.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
          ])

  # Context for this specific node, stable across pipeline runs.
  node_context = node.contexts.contexts.add()
  node_context.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context.name.field_value.string_value = (
      compiler_utils.node_context_name(
          pipeline_ctx.pipeline_info.pipeline_context_name,
          node.node_info.id))


def _set_node_outputs(node: pipeline_pb2.PipelineNode,
tfx_node_outputs: Dict[str, types.Channel]):
"""Compiles the outputs of a pipeline node."""
Expand Down
2 changes: 0 additions & 2 deletions tfx/dsl/compiler/compiler_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def __init__(self,
# Mapping from Channel object to compiled Channel proto.
self.channels = dict()

self.node_context_protos_cache: dict[str, pipeline_pb2.NodeContexts] = {}

# Node ID -> NodeContext
self._node_contexts: Dict[str, NodeContext] = {}

Expand Down
108 changes: 0 additions & 108 deletions tfx/dsl/compiler/node_contexts_compiler.py

This file was deleted.

157 changes: 0 additions & 157 deletions tfx/dsl/compiler/node_contexts_compiler_test.py

This file was deleted.

Loading

0 comments on commit 9ccee82

Please sign in to comment.