Cache subpipeline executions for partial run
PiperOrigin-RevId: 610557733
kmonte authored and tfx-copybara committed Feb 27, 2024
1 parent eea139a commit a900891
Showing 3 changed files with 129 additions and 3 deletions.
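
In short: when a partial run is initiated, initiate_pipeline_start now walks the pipeline for subpipelines whose begin node is marked for artifact reuse (including subpipelines nested inside them), registers a pipeline run for each in MLMD, and immediately stops that run with an OK status so the subpipeline's nodes read as cached. To support this, _should_attempt_to_reuse_artifact in partial_run_utils.py becomes the public should_attempt_to_reuse_artifact.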
50 changes: 50 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Pipeline-level operations."""

import collections
import contextlib
import copy
import dataclasses
@@ -184,6 +185,55 @@ def initiate_pipeline_start(
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT, message=str(e)
)
else:
# Find all subpipelines in the parent pipeline, which we are caching.
to_process = collections.deque([])
for node in pipeline.nodes:
# Only add to processing queue if it's a subpipeline that we are going
# to cache. For subpipelines, the begin node's (nodes[0]) execution
# options represent the subpipeline's execution options.
if node.WhichOneof(
'node'
) == 'sub_pipeline' and partial_run_utils.should_attempt_to_reuse_artifact(
node.sub_pipeline.nodes[0].pipeline_node.execution_options
):
to_process.append(node.sub_pipeline)
cached_subpipelines = []
while to_process:
subpipeline = to_process.popleft()
cached_subpipelines.append(subpipeline)
to_process.extend(
node.sub_pipeline
for node in subpipeline.nodes
if node.WhichOneof('node') == 'sub_pipeline'
)
logging.info(
'Found subpipelines: %s',
[s.pipeline_info.id for s in cached_subpipelines],
)
# Add a new pipeline run for every subpipeline we are going to cache in
# the partial run.
for subpipeline in cached_subpipelines:
reused_subpipeline_view = _load_reused_pipeline_view(
mlmd_handle, subpipeline, partial_run_option.snapshot_settings
)
# TODO: b/323912217 - Support putting multiple subpipeline executions
# into MLMD to handle the ForEach case.
with pstate.PipelineState.new(
mlmd_handle,
subpipeline,
pipeline_run_metadata,
reused_subpipeline_view,
) as subpipeline_state:
# TODO: b/320535460 - The new pipeline run should not be stopped if
# there are still nodes to run in it.
logging.info('Subpipeline execution cached for partial run.')
subpipeline_state.initiate_stop(
status_lib.Status(
code=status_lib.Code.OK,
message='Subpipeline execution cached for partial run.',
)
)
if pipeline.runtime_spec.HasField('snapshot_settings'):
try:
base_run_id = (
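
A minimal, self-contained sketch of the logic this hunk adds, with plain dataclasses standing in for the pipeline_pb2 protos. SubPipeline, collect_cached_subpipelines, reuse_marked, and nested are hypothetical names, and the register/stop callables stand in for pstate.PipelineState.new and initiate_stop:

import collections
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class SubPipeline:
  """Hypothetical stand-in for a pipeline_pb2 sub-pipeline node."""
  pipeline_id: str
  reuse_marked: bool = False  # True if the begin node's options say skip/reuse.
  nested: List['SubPipeline'] = field(default_factory=list)

def collect_cached_subpipelines(
    top_level: List[SubPipeline],
) -> List[SubPipeline]:
  """Breadth-first traversal with a deque, mirroring the loop in the diff.

  Top-level subpipelines are enqueued only if marked for reuse; every
  subpipeline nested under an enqueued one is then collected unconditionally.
  """
  to_process = collections.deque(sp for sp in top_level if sp.reuse_marked)
  cached = []
  while to_process:
    subpipeline = to_process.popleft()
    cached.append(subpipeline)
    to_process.extend(subpipeline.nested)
  return cached

def cache_subpipeline_runs(
    top_level: List[SubPipeline],
    register_run: Callable[[SubPipeline], object],
    stop_run: Callable[[object, str], None],
) -> None:
  """Registers a run per cached subpipeline, then immediately stops it as OK."""
  for subpipeline in collect_cached_subpipelines(top_level):
    run = register_run(subpipeline)  # ~ pstate.PipelineState.new(...)
    stop_run(run, 'Subpipeline execution cached for partial run.')

# Example: 'outer' is marked for reuse, so it and its nested 'inner' are cached.
inner = SubPipeline('inner')
outer = SubPipeline('outer', reuse_marked=True, nested=[inner])
print([sp.pipeline_id for sp in collect_cached_subpipelines([outer])])
# -> ['outer', 'inner']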
75 changes: 75 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -773,6 +773,81 @@ def test_initiate_pipeline_start_with_partial_run(self, mock_snapshot):
mock_snapshot.assert_called_once()
self.assertEqual(expected_pipeline, pipeline_state.pipeline)

@parameterized.named_parameters(
dict(
testcase_name='cache_subpipeline',
run_subpipeline=False,
),
dict(
testcase_name='run_subpipeline',
run_subpipeline=True,
),
)
@mock.patch.object(partial_run_utils, 'snapshot')
def test_initiate_pipeline_start_with_partial_run_and_subpipeline(
self, mock_snapshot, run_subpipeline
):
with self._mlmd_connection as m:
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
constants.PIPELINE_ROOT_PARAMETER_NAME: '/my/pipeline/root',
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run-0123',
},
)

expected_pipeline = copy.deepcopy(pipeline)
example_gen = expected_pipeline.nodes[0].pipeline_node
subpipeline = expected_pipeline.nodes[1].sub_pipeline
subpipeline_begin = subpipeline.nodes[0].pipeline_node
transform = expected_pipeline.nodes[2].pipeline_node
partial_run_utils.set_latest_pipeline_run_strategy(
expected_pipeline.runtime_spec.snapshot_settings
)

skip = pipeline_pb2.NodeExecutionOptions.Skip(
reuse_artifacts_mode=pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED
)
run = pipeline_pb2.NodeExecutionOptions.Run(
perform_snapshot=True, depends_on_snapshot=True
)
example_gen.execution_options.skip.CopyFrom(skip)

if run_subpipeline:
subpipeline_begin.execution_options.run.CopyFrom(run)
transform.execution_options.run.depends_on_snapshot = True
else:
subpipeline_begin.execution_options.skip.CopyFrom(skip)
transform.execution_options.run.CopyFrom(run)

partial_run_option = pipeline_pb2.PartialRun(
from_nodes=['sub-pipeline'] if run_subpipeline else ['my_transform'],
snapshot_settings=partial_run_utils.latest_pipeline_snapshot_settings(),
)
with pipeline_ops.initiate_pipeline_start(
m, pipeline, partial_run_option=partial_run_option
) as pipeline_state:
mock_snapshot.assert_called_once()
self.assertProtoEquals(expected_pipeline, pipeline_state.pipeline)

if run_subpipeline:
# If the subpipeline should be run then we should not have pre-loaded a
# run for it.
with self.assertRaises(status_lib.StatusNotOkError):
pstate.PipelineState.load_run(
m, 'sub-pipeline', 'sub-pipeline_run-0123'
)
else:
# Skipped subpipelines should have a run injected so their nodes are
# properly marked as cached.
with pstate.PipelineState.load_run(
m, 'sub-pipeline', 'sub-pipeline_run-0123'
) as subpipeline_state:
self.assertEqual(
subpipeline_state.stop_initiated_reason().code, status_lib.Code.OK
)

@mock.patch.object(partial_run_utils, 'snapshot')
def test_partial_run_with_previously_failed_nodes(self, mock_snapshot):
with self._mlmd_connection as m:
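
The two parameterized cases pin down the behavior on both sides: when the subpipeline is part of the partial run (run_subpipeline=True), no cached run may be registered for it, so loading 'sub-pipeline_run-0123' must raise; when it is skipped, the injected run must exist and report an OK stop-initiated reason.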
7 changes: 4 additions & 3 deletions tfx/orchestration/portable/partial_run_utils.py
@@ -153,7 +153,7 @@ def snapshot(mlmd_handle: metadata.Metadata,
"""
# Avoid unnecessary snapshotting step if no node needs to reuse any artifacts.
if not any(
- _should_attempt_to_reuse_artifact(node.pipeline_node.execution_options)
+ should_attempt_to_reuse_artifact(node.pipeline_node.execution_options)
for node in pipeline.nodes):
return

@@ -452,8 +452,9 @@ def _get_validated_new_run_id(pipeline: pipeline_pb2.Pipeline,
return str(inferred_new_run_id or new_run_id)


- def _should_attempt_to_reuse_artifact(
+ def should_attempt_to_reuse_artifact(
    execution_options: pipeline_pb2.NodeExecutionOptions):
+   """Returns whether artifacts should be reused for these execution options."""
return execution_options.HasField('skip') and (
execution_options.skip.reuse_artifacts or
execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_OPTIONAL or
@@ -512,7 +513,7 @@ def _reuse_pipeline_run_artifacts(
reuse_nodes = [
node
for node in node_proto_view.get_view_for_all_in(marked_pipeline)
- if _should_attempt_to_reuse_artifact(node.execution_options)
+ if should_attempt_to_reuse_artifact(node.execution_options)
]
logging.info(
'Reusing nodes: %s', [n.node_info.id for n in reuse_nodes]
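
For reference, a plain-Python paraphrase of the reuse predicate as far as the page shows it; the final clause is truncated above, so the sketch leaves it as a comment rather than guessing. ReuseMode, Skip, and ExecutionOptions are hypothetical stand-ins for the pipeline_pb2 messages:

import enum
from dataclasses import dataclass
from typing import Optional

class ReuseMode(enum.Enum):
  UNSPECIFIED = 0
  OPTIONAL = 1  # Stand-in for _REUSE_ARTIFACT_OPTIONAL.
  REQUIRED = 2  # The mode the test sets via Skip.REQUIRED.

@dataclass
class Skip:
  reuse_artifacts: bool = False
  reuse_artifacts_mode: ReuseMode = ReuseMode.UNSPECIFIED

@dataclass
class ExecutionOptions:
  skip: Optional[Skip] = None  # Proto HasField('skip') ~ skip is not None.

def should_attempt_to_reuse_artifact(options: ExecutionOptions) -> bool:
  """Returns whether artifacts should be reused for these execution options."""
  return options.skip is not None and (
      options.skip.reuse_artifacts
      or options.skip.reuse_artifacts_mode == ReuseMode.OPTIONAL
      # ... (the remaining clause is truncated in the diff above)
  )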
