Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix as_optional() #6756

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 19 additions & 68 deletions tfx/dsl/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.dsl.compiler import node_contexts_compiler
from tfx.dsl.compiler import node_execution_options_utils
from tfx.dsl.compiler import node_inputs_compiler
from tfx.dsl.components.base import base_component
Expand Down Expand Up @@ -56,7 +57,12 @@ def _compile_pipeline_begin_node(

# Step 2: Node Context
# Inner pipeline's contexts.
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
# Pipeline node inputs are stored as the inputs of the PipelineBegin node.
Expand Down Expand Up @@ -121,7 +127,12 @@ def _compile_pipeline_end_node(

# Step 2: Node Context
# Inner pipeline's contexts.
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -194,7 +205,12 @@ def _compile_node(
node.node_info.id = tfx_node.id

# Step 2: Node Context
_set_node_context(node, pipeline_ctx)
node.contexts.CopyFrom(
node_contexts_compiler.compile_node_contexts(
pipeline_ctx,
node.node_info.id,
)
)

# Step 3: Node inputs
node_inputs_compiler.compile_node_inputs(
Expand Down Expand Up @@ -386,71 +402,6 @@ def _validate_pipeline(tfx_pipeline: pipeline.Pipeline,
raise ValueError("Subpipeline has to be Sync execution mode.")


def _set_node_context(node: pipeline_pb2.PipelineNode,
                      pipeline_ctx: compiler_context.PipelineContext):
  """Compiles the node contexts of a pipeline node.

  Mutates ``node`` in place by appending, in order, to
  ``node.contexts.contexts``: the pipeline context, the pipeline-run context
  (sync mode only), the contexts inherited from each parent pipeline, and
  finally the per-node context.

  Args:
    node: The PipelineNode proto whose ``contexts`` field is populated.
    pipeline_ctx: Compilation context of the pipeline that owns the node.
  """
  # Context for the pipeline, across pipeline runs.
  pipeline_context_pb = node.contexts.contexts.add()
  pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context_pb.name.field_value.string_value = (
      pipeline_ctx.pipeline_info.pipeline_context_name)

  # Context for the current pipeline run. Only sync-mode pipelines carry a
  # pipeline-run context.
  if pipeline_ctx.is_sync_mode:
    pipeline_run_context_pb = node.contexts.contexts.add()
    pipeline_run_context_pb.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    # TODO(kennethyang): Migrate pipeline run id to structural_runtime_parameter
    # To keep existing IR textprotos used in tests unchanged, we only use
    # structural_runtime_parameter for subpipelines. After the subpipeline being
    # implemented, we will need to migrate normal pipelines to
    # structural_runtime_parameter as well for consistency. Similar for below.
    if pipeline_ctx.is_subpipeline:
      compiler_utils.set_structural_runtime_parameter_pb(
          pipeline_run_context_pb.name.structural_runtime_parameter, [
              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
          ])
    else:
      compiler_utils.set_runtime_parameter_pb(
          pipeline_run_context_pb.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)

  # Contexts inherited from the parent pipelines, iterated in reverse order
  # of ``parent_pipelines``.
  for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
    parent_pipeline_context_pb = node.contexts.contexts.add()
    parent_pipeline_context_pb.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    parent_pipeline_context_pb.name.field_value.string_value = (
        parent_pipeline.pipeline_info.pipeline_context_name)

    if parent_pipeline.execution_mode == pipeline.ExecutionMode.SYNC:
      pipeline_run_context_pb = node.contexts.contexts.add()
      pipeline_run_context_pb.type.name = (
          constants.PIPELINE_RUN_CONTEXT_TYPE_NAME)

      # TODO(kennethyang): Migrate pipeline run id to structural runtime
      # parameter for the similar reason mentioned above.
      # Use structural runtime parameter to represent pipeline_run_id except
      # for the root level pipeline, for backward compatibility.
      if i == len(pipeline_ctx.parent_pipelines) - 1:
        compiler_utils.set_runtime_parameter_pb(
            pipeline_run_context_pb.name.runtime_parameter,
            constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
      else:
        compiler_utils.set_structural_runtime_parameter_pb(
            pipeline_run_context_pb.name.structural_runtime_parameter, [
                f"{parent_pipeline.pipeline_info.pipeline_context_name}_",
                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str)
            ])

  # Context for the node, across pipeline runs.
  node_context_pb = node.contexts.contexts.add()
  node_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context_pb.name.field_value.string_value = (
      compiler_utils.node_context_name(
          pipeline_ctx.pipeline_info.pipeline_context_name,
          node.node_info.id))


def _set_node_outputs(node: pipeline_pb2.PipelineNode,
tfx_node_outputs: Dict[str, types.Channel]):
"""Compiles the outputs of a pipeline node."""
Expand Down
2 changes: 2 additions & 0 deletions tfx/dsl/compiler/compiler_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self,
# Mapping from Channel object to compiled Channel proto.
self.channels = dict()

self.node_context_protos_cache: dict[str, pipeline_pb2.NodeContexts] = {}

# Node ID -> NodeContext
self._node_contexts: Dict[str, NodeContext] = {}

Expand Down
108 changes: 108 additions & 0 deletions tfx/dsl/compiler/node_contexts_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compiles NodeContexts."""

from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.orchestration import pipeline
from tfx.proto.orchestration import pipeline_pb2


def compile_node_contexts(
    pipeline_ctx: compiler_context.PipelineContext,
    node_id: str,
) -> pipeline_pb2.NodeContexts:
  """Compiles the NodeContexts proto of a single pipeline node.

  The compiled proto is memoized on
  ``pipeline_ctx.node_context_protos_cache`` keyed by node id, so compiling
  the same node a second time reuses the first result.

  Args:
    pipeline_ctx: Compilation context of the pipeline that owns the node.
    node_id: Id of the node whose contexts are compiled.

  Returns:
    A NodeContexts proto containing, in order: the pipeline context, the
    pipeline-run context (sync mode only), the contexts inherited from each
    parent pipeline, and the per-node context. Empty when the context has no
    pipeline info.
  """
  if pipeline_ctx.pipeline_info is None:
    # Without pipeline info there is nothing to name the contexts after.
    return pipeline_pb2.NodeContexts()

  cached = pipeline_ctx.node_context_protos_cache.get(node_id)
  if cached:
    return cached

  result = pipeline_pb2.NodeContexts()

  # Context for the pipeline, across pipeline runs.
  pipeline_context = result.contexts.add()
  pipeline_context.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
  pipeline_context.name.field_value.string_value = (
      pipeline_ctx.pipeline_info.pipeline_context_name
  )

  # Context for the current pipeline run; only sync-mode pipelines have one.
  if pipeline_ctx.is_sync_mode:
    run_context = result.contexts.add()
    run_context.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
    # TODO(kennethyang): Migrate pipeline run id to
    # structural_runtime_parameter. To keep existing IR textprotos used in
    # tests unchanged, structural_runtime_parameter is only used for
    # subpipelines for now; normal pipelines should follow later for
    # consistency. Similar for the parent-pipeline loop below.
    if pipeline_ctx.is_subpipeline:
      compiler_utils.set_structural_runtime_parameter_pb(
          run_context.name.structural_runtime_parameter,
          [
              f"{pipeline_ctx.pipeline_info.pipeline_context_name}_",
              (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
          ],
      )
    else:
      compiler_utils.set_runtime_parameter_pb(
          run_context.name.runtime_parameter,
          constants.PIPELINE_RUN_ID_PARAMETER_NAME,
          str,
      )

  # Contexts inherited from the parent pipelines, iterated in reverse order
  # of ``parent_pipelines``.
  num_parents = len(pipeline_ctx.parent_pipelines)
  for i, parent in enumerate(reversed(pipeline_ctx.parent_pipelines)):
    parent_context = result.contexts.add()
    parent_context.type.name = constants.PIPELINE_CONTEXT_TYPE_NAME
    parent_context.name.field_value.string_value = (
        parent.pipeline_info.pipeline_context_name
    )

    if parent.execution_mode == pipeline.ExecutionMode.SYNC:
      parent_run_context = result.contexts.add()
      parent_run_context.type.name = constants.PIPELINE_RUN_CONTEXT_TYPE_NAME

      # TODO(kennethyang): Migrate pipeline run id to structural runtime
      # parameter for the similar reason mentioned above.
      # Use structural runtime parameter to represent pipeline_run_id except
      # for the root level pipeline, for backward compatibility.
      if i == num_parents - 1:
        compiler_utils.set_runtime_parameter_pb(
            parent_run_context.name.runtime_parameter,
            constants.PIPELINE_RUN_ID_PARAMETER_NAME,
            str,
        )
      else:
        compiler_utils.set_structural_runtime_parameter_pb(
            parent_run_context.name.structural_runtime_parameter,
            [
                f"{parent.pipeline_info.pipeline_context_name}_",
                (constants.PIPELINE_RUN_ID_PARAMETER_NAME, str),
            ],
        )

  # Context for the node itself, across pipeline runs.
  node_context = result.contexts.add()
  node_context.type.name = constants.NODE_CONTEXT_TYPE_NAME
  node_context.name.field_value.string_value = (
      compiler_utils.node_context_name(
          pipeline_ctx.pipeline_info.pipeline_context_name, node_id
      )
  )

  pipeline_ctx.node_context_protos_cache[node_id] = result
  return result
157 changes: 157 additions & 0 deletions tfx/dsl/compiler/node_contexts_compiler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.compiler.node_contexts_compiler."""

import tensorflow as tf
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import node_contexts_compiler
from tfx.orchestration import pipeline
from tfx.proto.orchestration import pipeline_pb2

from google.protobuf import text_format

_NODE_ID = 'test_node'
_PIPELINE_NAME = 'test_pipeline'


class NodeContextsCompilerTest(tf.test.TestCase):
  """Tests for node_contexts_compiler.compile_node_contexts."""

  def test_compile_node_contexts(self):
    """A top-level sync pipeline gets pipeline, pipeline_run and node contexts."""
    expected_node_contexts = text_format.Parse(
        """
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "test_pipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            runtime_parameter {
              name: "pipeline-run-id"
              type: STRING
            }
          }
        }
        contexts {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "test_pipeline.test_node"
            }
          }
        }
        """,
        pipeline_pb2.NodeContexts(),
    )
    self.assertProtoEquals(
        node_contexts_compiler.compile_node_contexts(
            compiler_context.PipelineContext(pipeline.Pipeline(_PIPELINE_NAME)),
            _NODE_ID,
        ),
        expected_node_contexts,
    )

  def test_compile_node_contexts_for_subpipeline(self):
    """A subpipeline node also inherits its parent pipeline's contexts.

    Expected order: subpipeline context, subpipeline run (as a structural
    runtime parameter), parent pipeline context, parent run (as a plain
    runtime parameter, since the parent is the root pipeline), node context.
    """
    parent_context = compiler_context.PipelineContext(
        pipeline.Pipeline(_PIPELINE_NAME)
    )
    subpipeline_context = compiler_context.PipelineContext(
        pipeline.Pipeline('subpipeline'), parent_context
    )

    expected_node_contexts = text_format.Parse(
        """
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "subpipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            structural_runtime_parameter {
              parts {
                constant_value: "subpipeline_"
              }
              parts {
                runtime_parameter {
                  name: "pipeline-run-id"
                  type: STRING
                }
              }
            }
          }
        }
        contexts {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "test_pipeline"
            }
          }
        }
        contexts {
          type {
            name: "pipeline_run"
          }
          name {
            runtime_parameter {
              name: "pipeline-run-id"
              type: STRING
            }
          }
        }
        contexts {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "subpipeline.test_node"
            }
          }
        }
        """,
        pipeline_pb2.NodeContexts(),
    )
    self.assertProtoEquals(
        node_contexts_compiler.compile_node_contexts(
            subpipeline_context,
            _NODE_ID,
        ),
        expected_node_contexts,
    )


# Standard TF test entry point: run all test cases when executed directly.
if __name__ == '__main__':
  tf.test.main()
Loading