Skip to content

Commit

Permalink
Fix as_optional()
Browse files Browse the repository at this point in the history
Because we use `id` as a hash and `as_optional()` creates a new object, this check [1] will not pass; instead we fall through to [2], which does not add the pipeline run context.

PiperOrigin-RevId: 625736342
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 17, 2024
1 parent 9332479 commit 0313e7c
Show file tree
Hide file tree
Showing 8 changed files with 584 additions and 84 deletions.
87 changes: 19 additions & 68 deletions tfx/dsl/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.dsl.compiler import node_contexts_compiler
from tfx.dsl.compiler import node_execution_options_utils
from tfx.dsl.compiler import node_inputs_compiler
from tfx.dsl.components.base import base_component
Expand Down Expand Up @@ -56,7 +57,12 @@ def _compile_pipeline_begin_node(

# Step 2: Node Context
# Inner pipeline's contexts.
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
# Pipeline node inputs are stored as the inputs of the PipelineBegin node.
Expand Down Expand Up @@ -121,7 +127,12 @@ def _compile_pipeline_end_node(

# Step 2: Node Context
# Inner pipeline's contexts.
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -194,7 +205,12 @@ def _compile_node(
node.node_info.id = tfx_node.id

# Step 2: Node Context
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -386,71 +402,6 @@ def _validate_pipeline(tfx_pipeline: pipeline.Pipeline,
raise ValueError("Subpipeline has to be Sync execution mode.")


def _set_node_context(node: pipeline_pb2.PipelineNode,
                      pipeline_ctx: compiler_context.PipelineContext):
  """Compiles the node contexts of a pipeline node.

  Mutates `node.contexts` in place, appending contexts in a fixed order:
  (1) the pipeline context, (2) the current pipeline-run context (sync mode
  only), (3) pipeline and pipeline-run contexts inherited from each parent
  pipeline (outermost parent first), and (4) the per-node context.

  Args:
    node: The `PipelineNode` proto whose `contexts` field is populated.
    pipeline_ctx: Compilation context for the pipeline that owns `node`.
  """
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      pipeline_ctx.pipeline_info.pipeline_context_name)

  # Context for the current pipeline run.
  if pipeline_ctx.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
    # To keep existing IR textprotos used in tests unchanged, we only use
    # structural_runtime_parameter for subpipelines. After the subpipeline being
    # implemented, we will need to migrate normal pipelines to
    # structural_runtime_parameter as well for consistency. Similar for below.
    if pipeline_ctx.is_subpipeline:
      # Subpipeline run name is "<pipeline_context_name>_<pipeline-run-id>",
      # expressed as a structural runtime parameter resolved at run time.
      compiler_utils.set_structural_runtime_parameter_pb(
          pipeline_run_context_pb.name.structural_runtime_parameter, [
              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
          ])
    else:
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Contexts inherited from the parent pipelines.
  # parent_pipelines is reversed so that the outermost parent comes first.
  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
    parent_pipeline_context_pb = node.contexts.contexts.add()
    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    parent_pipeline_context_pb.name.field_value.string_value = (
        parent_pipeline.pipeline_info.pipeline_context_name)

    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
      pipeline_run_context_pb = node.contexts.contexts.add()
      pipeline_run_context_pb.type.name = (
          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)

      # TODO(kennethyang): Migrate pipeline run id to structural runtime
      # parameter for the similar reason mentioned above.
      # Use structural runtime parameter to represent pipeline_run_id except
      # for the root level pipeline, for backward compatibility.
      if i == len(pipeline_ctx.parent_pipelines) - 1:
        compiler_utils.set_runtime_parameter_pb(
            pipeline_run_context_pb.name.runtime_parameter,
            constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
      else:
        compiler_utils.set_structural_runtime_parameter_pb(
            pipeline_run_context_pb.name.structural_runtime_parameter, [
                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
            ])

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          pipeline_ctx.pipeline_info.pipeline_context_name,
          node.node_info.id))


def _set_node_outputs(node: pipeline_pb2.PipelineNode,
tfx_node_outputs: Dict[str, types.Channel]):
"""Compiles the outputs of a pipeline node."""
Expand Down
2 changes: 2 additions & 0 deletions tfx/dsl/compiler/compiler_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self,
# Mapping from Channel object to compiled Channel proto.
self.channels = dict()

self.node_context_protos_cache: dict[str, pipeline_pb2.NodeContexts] = {}

# Node ID -> NodeContext
self._node_contexts: Dict[str, NodeContext] = {}

Expand Down
108 changes: 108 additions & 0 deletions tfx/dsl/compiler/node_contexts_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compiles NodeContexts."""

from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.orchestration import pipeline
from tfx.proto.orchestration import pipeline_pb2


def compile_node_contexts(
    pipeline_ctx: compiler_context.PipelineContext,
    node_id: str,
) -> pipeline_pb2.NodeContexts:
  """Compiles the node contexts of a pipeline node.

  Builds, in a fixed order: (1) the pipeline context, (2) the current
  pipeline-run context (sync mode only), (3) pipeline and pipeline-run
  contexts inherited from each parent pipeline (outermost parent first),
  and (4) the per-node context. Results are memoized per node id in
  `pipeline_ctx.node_context_protos_cache`.

  Args:
    pipeline_ctx: Compilation context for the pipeline that owns the node.
    node_id: ID of the node whose contexts are compiled.

  Returns:
    A `NodeContexts` proto; empty if `pipeline_ctx.pipeline_info` is unset.
  """

  if pipeline_ctx.pipeline_info is None:
    return pipeline_pb2.NodeContexts()
  # Use an explicit membership test rather than truthiness of `.get()`: an
  # empty NodeContexts proto is falsy, so a cached empty result would
  # otherwise be treated as a cache miss and recomputed.
  if node_id in pipeline_ctx.node_context_protos_cache:
    return pipeline_ctx.node_context_protos_cache[node_id]

  node_contexts = pipeline_pb2.NodeContexts()
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node_contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      pipeline_ctx.pipeline_info.pipeline_context_name
  )

  # Context for the current pipeline run.
  if pipeline_ctx.is_sync_mode:
    pipeline_run_context_pb = node_contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
    # To keep existing IR textprotos used in tests unchanged, we only use
    # structural_runtime_parameter for subpipelines. After the subpipeline being
    # implemented, we will need to migrate normal pipelines to
    # structural_runtime_parameter as well for consistency. Similar for below.
    if pipeline_ctx.is_subpipeline:
      # Subpipeline run name is "<pipeline_context_name>_<pipeline-run-id>",
      # resolved at run time via a structural runtime parameter.
      compiler_utils.set_structural_runtime_parameter_pb(
          pipeline_run_context_pb.name.structural_runtime_parameter,
          [
              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
          ],
      )
    else:
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME,
          str,
      )

  # Contexts inherited from the parent pipelines.
  # parent_pipelines is reversed so the outermost parent comes first.
  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
    parent_pipeline_context_pb = node_contexts.contexts.add()
    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    parent_pipeline_context_pb.name.field_value.string_value = (
        parent_pipeline.pipeline_info.pipeline_context_name
    )

    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
      pipeline_run_context_pb = node_contexts.contexts.add()
      pipeline_run_context_pb.type.name = (
          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
      )

      # TODO(kennethyang): Migrate pipeline run id to structural runtime
      # parameter for the similar reason mentioned above.
      # Use structural runtime parameter to represent pipeline_run_id except
      # for the root level pipeline, for backward compatibility.
      if i == len(pipeline_ctx.parent_pipelines) - 1:
        compiler_utils.set_runtime_parameter_pb(
            pipeline_run_context_pb.name.runtime_parameter,
            constants.PIPELINE_RUN_ID_PARAMETER_NAME,
            str,
        )
      else:
        compiler_utils.set_structural_runtime_parameter_pb(
            pipeline_run_context_pb.name.structural_runtime_parameter,
            [
                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
            ],
        )

  # Context for the node, across pipeline runs.
  node_context_pb = node_contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          pipeline_ctx.pipeline_info.pipeline_context_name, node_id
      )
  )
  pipeline_ctx.node_context_protos_cache[node_id] = node_contexts
  return node_contexts
157 changes: 157 additions & 0 deletions tfx/dsl/compiler/node_contexts_compiler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.compiler.node_contexts_compiler."""

import tensorflow as tf
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import node_contexts_compiler
from tfx.orchestration import pipeline
from tfx.proto.orchestration import pipeline_pb2

from google.protobuf import text_format

_NODE_ID = 'test_node'
_PIPELINE_NAME = 'test_pipeline'


class NodeContextsCompilerTest(tf.test.TestCase):
  """Unit tests for node_contexts_compiler.compile_node_contexts."""

  def test_compile_node_contexts(self):
    """A standalone sync pipeline yields pipeline, run, and node contexts."""
    ctx = compiler_context.PipelineContext(pipeline.Pipeline(_PIPELINE_NAME))
    actual = node_contexts_compiler.compile_node_contexts(ctx, _NODE_ID)

    want = text_format.Parse(
        """
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "test_pipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            runtime_parameter {
              name: "pipeline-run-id"
              type: STRING
            }
          }
        }
        contexts {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "test_pipeline.test_node"
            }
          }
        }
        """,
        pipeline_pb2.NodeContexts(),
    )
    self.assertProtoEquals(actual, want)

  def test_compile_node_contexts_for_subpipeline(self):
    """A subpipeline node also inherits its parent's contexts."""
    outer_ctx = compiler_context.PipelineContext(
        pipeline.Pipeline(_PIPELINE_NAME)
    )
    inner_ctx = compiler_context.PipelineContext(
        pipeline.Pipeline('subpipeline'), outer_ctx
    )
    actual = node_contexts_compiler.compile_node_contexts(inner_ctx, _NODE_ID)

    # Expected order: subpipeline context, subpipeline run (structural
    # runtime parameter), parent pipeline context, parent run, node context.
    want = text_format.Parse(
        """
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "subpipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            structural_runtime_parameter {
              parts {
                constant_value: "subpipeline_"
              }
              parts {
                runtime_parameter {
                  name: "pipeline-run-id"
                  type: STRING
                }
              }
            }
          }
        }
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "test_pipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            runtime_parameter {
              name: "pipeline-run-id"
              type: STRING
            }
          }
        }
        contexts {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "subpipeline.test_node"
            }
          }
        }
        """,
        pipeline_pb2.NodeContexts(),
    )
    self.assertProtoEquals(actual, want)


# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
  tf.test.main()
Loading

0 comments on commit 0313e7c

Please sign in to comment.