Skip to content

Commit

Permalink
Add subpipeline_utils to contain utils for orchestrating subpipelines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625756178
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 17, 2024
1 parent 0313e7c commit ef4dd95
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tfx/orchestration/subpipeline_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic utilities for orchestrating subpipelines."""


from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants as compiler_constants
from tfx.orchestration import pipeline as dsl_pipeline
from tfx.proto.orchestration import pipeline_pb2

# This pipeline *only* exists so that we can correctly infer the correct node
# types for pipeline begin and end nodes, as the compiler uses a Python Pipeline
# object to generate the names.
# This pipeline *should not* be used otherwise.
_DUMMY_PIPELINE = dsl_pipeline.Pipeline(pipeline_name="UNUSED")


def is_subpipeline(pipeline: pipeline_pb2.Pipeline) -> bool:
"""Returns True if the pipeline is a subpipeline."""
nodes = pipeline.nodes
if len(nodes) < 2:
return False
maybe_begin_node = nodes[0]
maybe_end_node = nodes[-1]
if (
maybe_begin_node.WhichOneof("node") != "pipeline_node"
or maybe_begin_node.pipeline_node.node_info.id
!= f"{pipeline.pipeline_info.id}{compiler_constants.PIPELINE_BEGIN_NODE_SUFFIX}"
or maybe_begin_node.pipeline_node.node_info.type.name
!= compiler_utils.pipeline_begin_node_type_name(_DUMMY_PIPELINE)
):
return False
if (
maybe_end_node.WhichOneof("node") != "pipeline_node"
or maybe_end_node.pipeline_node.node_info.id
!= compiler_utils.pipeline_end_node_id_from_pipeline_id(
pipeline.pipeline_info.id
)
or maybe_end_node.pipeline_node.node_info.type.name
!= compiler_utils.pipeline_end_node_type_name(_DUMMY_PIPELINE)
):
return False
return True
47 changes: 47 additions & 0 deletions tfx/orchestration/subpipeline_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.orchestration.subpipeline_utils."""

from absl.testing import absltest
from absl.testing import parameterized
from tfx.dsl.compiler import compiler
from tfx.orchestration import pipeline as dsl_pipeline
from tfx.orchestration import subpipeline_utils

_PIPELINE_NAME = 'test_pipeline'
_TEST_PIPELINE = dsl_pipeline.Pipeline(pipeline_name=_PIPELINE_NAME)


class SubpipelineUtilsTest(parameterized.TestCase):

def test_is_subpipeline_with_subpipeline(self):
subpipeline = dsl_pipeline.Pipeline(pipeline_name='subpipeline')
pipeline = dsl_pipeline.Pipeline(
pipeline_name=_PIPELINE_NAME, components=[subpipeline]
)
pipeline_ir = compiler.Compiler().compile(pipeline)
subpipeline_ir = pipeline_ir.nodes[0].sub_pipeline
self.assertTrue(subpipeline_utils.is_subpipeline(subpipeline_ir))

def test_is_subpipeline_with_parent_pipelines(self):
subpipeline = dsl_pipeline.Pipeline(pipeline_name='subpipeline')
pipeline = dsl_pipeline.Pipeline(
pipeline_name=_PIPELINE_NAME, components=[subpipeline]
)
pipeline_ir = compiler.Compiler().compile(pipeline)
self.assertFalse(subpipeline_utils.is_subpipeline(pipeline_ir))


if __name__ == '__main__':
absltest.main()

0 comments on commit ef4dd95

Please sign in to comment.