diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index ed2188deb4..b9a6a115b8 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -13,6 +13,7 @@ # limitations under the License. """Pipeline-level operations.""" +import collections import contextlib import copy import dataclasses @@ -184,6 +185,55 @@ def initiate_pipeline_start( raise status_lib.StatusNotOkError( code=status_lib.Code.INVALID_ARGUMENT, message=str(e) ) + else: + # Find all subpipelines in the parent pipeline, which we are caching. + to_process = collections.deque([]) + for node in pipeline.nodes: + # Only add to processing queue if it's a subpipeline that we are going + # to cache. For subpipelines, the begin node's (nodes[0]) execution + # options repersent the subpipeline's execution options. + if node.WhichOneof( + 'node' + ) == 'sub_pipeline' and partial_run_utils.should_attempt_to_reuse_artifact( + node.sub_pipeline.nodes[0].pipeline_node.execution_options + ): + to_process.append(node.sub_pipeline) + cached_subpipelines = [] + while to_process: + subpipeline = to_process.popleft() + cached_subpipelines.append(subpipeline) + to_process.extend( + node.sub_pipeline + for node in subpipeline.nodes + if node.WhichOneof('node') == 'sub_pipeline' + ) + logging.info( + 'Found subpipelines: %s', + [s.pipeline_info.id for s in cached_subpipelines], + ) + # Add a new pipeline run for every subpipeline we are going to cache in + # the partial run. + for subpipeline in cached_subpipelines: + reused_subpipeline_view = _load_reused_pipeline_view( + mlmd_handle, subpipeline, partial_run_option.snapshot_settings + ) + # TODO: b/323912217 - Support putting multiple subpipeline executions + # into MLMD to handle the ForEach case. + with pstate.PipelineState.new( + mlmd_handle, + subpipeline, + pipeline_run_metadata, + reused_subpipeline_view, + ) as subpipeline_state: + # TODO: b/320535460 - The new pipeline run should not be stopped if + # there are still nodes to run in it. + logging.info('Subpipeline execution cached for partial run.') + subpipeline_state.initiate_stop( + status_lib.Status( + code=status_lib.Code.OK, + message='Subpipeline execution cached for partial run.', + ) + ) if pipeline.runtime_spec.HasField('snapshot_settings'): try: base_run_id = ( diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index 137a12f3b2..9fbba2080c 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -773,6 +773,81 @@ def test_initiate_pipeline_start_with_partial_run(self, mock_snapshot): mock_snapshot.assert_called_once() self.assertEqual(expected_pipeline, pipeline_state.pipeline) + @parameterized.named_parameters( + dict( + testcase_name='cache_subpipeline', + run_subpipeline=False, + ), + dict( + testcase_name='run_subpipeline', + run_subpipeline=True, + ), + ) + @mock.patch.object(partial_run_utils, 'snapshot') + def test_initiate_pipeline_start_with_partial_run_and_subpipeline( + self, mock_snapshot, run_subpipeline + ): + with self._mlmd_connection as m: + pipeline = test_sync_pipeline.create_pipeline_with_subpipeline() + runtime_parameter_utils.substitute_runtime_parameter( + pipeline, + { + constants.PIPELINE_ROOT_PARAMETER_NAME: '/my/pipeline/root', + constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run-0123', + }, + ) + + expected_pipeline = copy.deepcopy(pipeline) + example_gen = expected_pipeline.nodes[0].pipeline_node + subpipeline = expected_pipeline.nodes[1].sub_pipeline + subpipeline_begin = subpipeline.nodes[0].pipeline_node + transform = expected_pipeline.nodes[2].pipeline_node + partial_run_utils.set_latest_pipeline_run_strategy( + expected_pipeline.runtime_spec.snapshot_settings + ) + + skip = pipeline_pb2.NodeExecutionOptions.Skip( + reuse_artifacts_mode=pipeline_pb2.NodeExecutionOptions.Skip.REQUIRED + ) + run = pipeline_pb2.NodeExecutionOptions.Run( + perform_snapshot=True, depends_on_snapshot=True + ) + example_gen.execution_options.skip.CopyFrom(skip) + + if run_subpipeline: + subpipeline_begin.execution_options.run.CopyFrom(run) + transform.execution_options.run.depends_on_snapshot = True + else: + subpipeline_begin.execution_options.skip.CopyFrom(skip) + transform.execution_options.run.CopyFrom(run) + + partial_run_option = pipeline_pb2.PartialRun( + from_nodes=['sub-pipeline'] if run_subpipeline else ['my_transform'], + snapshot_settings=partial_run_utils.latest_pipeline_snapshot_settings(), + ) + with pipeline_ops.initiate_pipeline_start( + m, pipeline, partial_run_option=partial_run_option + ) as pipeline_state: + mock_snapshot.assert_called_once() + self.assertProtoEquals(expected_pipeline, pipeline_state.pipeline) + + if run_subpipeline: + # If the subpipeline should be run then we should not have pre-loaded a + # run for it. + with self.assertRaises(status_lib.StatusNotOkError): + pstate.PipelineState.load_run( + m, 'sub-pipeline', 'sub-pipeline_run-0123' + ) + else: + # Skipped subpipelines should have a run injected so their nodes are + # properly marked as cached. + with pstate.PipelineState.load_run( + m, 'sub-pipeline', 'sub-pipeline_run-0123' + ) as subpipeline_state: + self.assertEqual( + subpipeline_state.stop_initiated_reason().code, status_lib.Code.OK + ) + @mock.patch.object(partial_run_utils, 'snapshot') def test_partial_run_with_previously_failed_nodes(self, mock_snapshot): with self._mlmd_connection as m: diff --git a/tfx/orchestration/portable/partial_run_utils.py b/tfx/orchestration/portable/partial_run_utils.py index 1d50f3a50c..2c7b33d088 100644 --- a/tfx/orchestration/portable/partial_run_utils.py +++ b/tfx/orchestration/portable/partial_run_utils.py @@ -153,7 +153,7 @@ def snapshot(mlmd_handle: metadata.Metadata, """ # Avoid unnecessary snapshotting step if no node needs to reuse any artifacts. if not any( - _should_attempt_to_reuse_artifact(node.pipeline_node.execution_options) + should_attempt_to_reuse_artifact(node.pipeline_node.execution_options) for node in pipeline.nodes): return @@ -452,8 +452,9 @@ def _get_validated_new_run_id(pipeline: pipeline_pb2.Pipeline, return str(inferred_new_run_id or new_run_id) -def _should_attempt_to_reuse_artifact( +def should_attempt_to_reuse_artifact( execution_options: pipeline_pb2.NodeExecutionOptions): + """Returns whether artifacts should be reused for the these execution options.""" return execution_options.HasField('skip') and ( execution_options.skip.reuse_artifacts or execution_options.skip.reuse_artifacts_mode == _REUSE_ARTIFACT_OPTIONAL or @@ -512,7 +513,7 @@ def _reuse_pipeline_run_artifacts( reuse_nodes = [ node for node in node_proto_view.get_view_for_all_in(marked_pipeline) - if _should_attempt_to_reuse_artifact(node.execution_options) + if should_attempt_to_reuse_artifact(node.execution_options) ] logging.info( 'Reusing nodes: %s', [n.node_info.id for n in reuse_nodes]