diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 19a4bba68b..76188665f9 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -33,6 +33,7 @@ from tfx.dsl.io import filesystem from tfx.orchestration import metadata from tfx.orchestration import node_proto_view +from tfx.orchestration import subpipeline_utils from tfx.orchestration.experimental.core import async_pipeline_task_gen from tfx.orchestration.experimental.core import constants from tfx.orchestration.experimental.core import env @@ -1252,6 +1253,20 @@ def filter_by_pipeline_uid( return lambda p: p.pipeline_uid == pipeline_uid +def _record_orchestration_time(pipeline_state: pstate.PipelineState) -> None: + """Records an orchestration time for the pipeline run.""" + # We only care about orchestration time for root pipelines, skip any + # subpipelines. + if subpipeline_utils.is_subpipeline(pipeline_state.pipeline): + return + pipeline_run_id = pipeline_state.pipeline_run_id + # Backend expects an empty string for the pipeline run id, for ASYNC pipeline + # runs. + if pipeline_run_id is None: + pipeline_run_id = '' + env.get_env().record_orchestration_time(pipeline_run_id) + + @_pipeline_op() def orchestrate( mlmd_connection_manager: mlmd_cm.MLMDConnectionManager, @@ -1322,6 +1337,7 @@ def orchestrate( service_job_manager, pipeline_state, ) + _record_orchestration_time(pipeline_state) except Exception: # pylint: disable=broad-except # If orchestrating a stop-initiated pipeline raises an exception, we log # the exception but do not re-raise since we do not want to crash the @@ -1345,6 +1361,7 @@ def orchestrate( service_job_manager, pipeline_state, ) + _record_orchestration_time(pipeline_state) except Exception as e: # pylint: disable=broad-except logging.exception( 'Exception raised while orchestrating update-initiated pipeline %s', @@ -1364,6 +1381,7 @@ def orchestrate( ), ) ) + _record_orchestration_time(pipeline_state) except Exception: # pylint: disable=broad-except # If stop initiation also raised an exception , we log the exception but # do not re-raise since we do not want to crash the orchestrator. If @@ -1387,6 +1405,7 @@ def orchestrate( service_job_manager, pipeline_state, ) + _record_orchestration_time(pipeline_state) except Exception as e: # pylint: disable=broad-except logging.exception( 'Exception raised while orchestrating active pipeline %s', @@ -1404,6 +1423,7 @@ def orchestrate( message=f'Error orchestrating active pipeline: {str(e)}', ) ) + _record_orchestration_time(pipeline_state) except Exception: # pylint: disable=broad-except # If stop initiation also raised an exception , we log the exception but # do not re-raise since we do not want to crash the orchestrator. If diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index da4574146e..56bb115187 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -1434,10 +1434,73 @@ def test_stop_node_wait_for_inactivation_timeout(self): (pstate.NodeState.STOPPING, pstate.NodeState.STOPPED), ) + @parameterized.named_parameters( + dict( + testcase_name='async', + pipeline=_test_pipeline('pipeline1'), + expected_run_id='', + ), + dict( + testcase_name='sync', + pipeline=_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC), + expected_run_id='run0', + ), + ) + def test_record_orchestration_time(self, pipeline, expected_run_id): + with self._mlmd_cm as mlmd_connection_manager: + m = mlmd_connection_manager.primary_mlmd_handle + pipeline_ops.initiate_pipeline_start(m, pipeline) + environment = env.get_env() + with mock.patch.object( + environment, + 'record_orchestration_time', + wraps=environment.record_orchestration_time, + ) as mock_env_record_orchestration_time: + task_queue = tq.TaskQueue() + pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + self._mock_service_job_manager, + ) + mock_env_record_orchestration_time.assert_called_with(expected_run_id) + + def test_record_orchestration_time_subpipeline(self): + with self._mlmd_cm as mlmd_connection_manager: + m = mlmd_connection_manager.primary_mlmd_handle + pipeline = test_sync_pipeline.create_pipeline_with_subpipeline() + runtime_parameter_utils.substitute_runtime_parameter( + pipeline, + { + constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0', + }, + ) + pipeline_ops.initiate_pipeline_start(m, pipeline) + environment = env.get_env() + with mock.patch.object( + environment, + 'record_orchestration_time', + wraps=environment.record_orchestration_time, + ) as mock_env_record_orchestration_time: + task_queue = tq.TaskQueue() + pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + self._mock_service_job_manager, + ) + mock_env_record_orchestration_time.assert_called_with('run0') + @mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator') @mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator') + @mock.patch.object( + pipeline_ops, + '_record_orchestration_time', + wraps=pipeline_ops._record_orchestration_time, + ) def test_orchestrate_active_pipelines( - self, mock_async_task_gen, mock_sync_task_gen + self, + mock_record_orchestration_time, + mock_async_task_gen, + mock_sync_task_gen, ): with self._mlmd_cm as mlmd_connection_manager: m = mlmd_connection_manager.primary_mlmd_handle @@ -1509,6 +1572,15 @@ def test_orchestrate_active_pipelines( service_jobs.DummyServiceJobManager(), ) + # Check that the orchestration time was recorded four times. Once for each + # of the four pipelines. + mock_record_orchestration_time.assert_has_calls([ + mock.call(mock.ANY), + mock.call(mock.ANY), + mock.call(mock.ANY), + mock.call(mock.ANY), + ]) + self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count) self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count) @@ -1550,9 +1622,15 @@ def test_orchestrate_active_pipelines( @mock.patch.object( task_gen_utils, 'generate_cancel_task_from_running_execution' ) + @mock.patch.object( + pipeline_ops, + '_record_orchestration_time', + wraps=pipeline_ops._record_orchestration_time, + ) def test_orchestrate_stop_initiated_pipelines( self, pipeline, + mock_record_orchestration_time, mock_gen_task_from_active, mock_async_task_gen, mock_sync_task_gen, @@ -1617,6 +1695,10 @@ def recorder(event): self._mock_service_job_manager, ) ) + # We should have recorded the orchestration time once, for one pipeline. + # We reset after to verify this is true throughout. + mock_record_orchestration_time.assert_called_once() + mock_record_orchestration_time.reset_mock() # PipelineFinished event should not trigger since not all the nodes are # stopped. @@ -1683,6 +1765,8 @@ def recorder(event): self._mock_service_job_manager, ) ) + mock_record_orchestration_time.assert_called_once() + mock_record_orchestration_time.reset_mock() self.assertTrue(task_queue.is_empty()) [execution] = m.store.get_executions_by_id([pipeline_execution_id]) self.assertEqual( @@ -1721,6 +1805,7 @@ def recorder(event): self._mock_service_job_manager, ) ) + mock_record_orchestration_time.assert_not_called() @mock.patch.object( task_gen_utils, 'generate_cancel_task_from_running_execution' @@ -1890,7 +1975,14 @@ def recorder(event): _test_pipeline('pipeline1'), _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC), ) - def test_orchestrate_update_initiated_pipelines(self, pipeline): + @mock.patch.object( + pipeline_ops, + '_record_orchestration_time', + wraps=pipeline_ops._record_orchestration_time, + ) + def test_orchestrate_update_initiated_pipelines( + self, pipeline, mock_record_orchestration_time + ): with self._mlmd_cm as mlmd_connection_manager: m = mlmd_connection_manager.primary_mlmd_handle pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen' @@ -1924,6 +2016,10 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline): pipeline_ops.orchestrate( mlmd_connection_manager, task_queue, self._mock_service_job_manager ) + # We should have recorded the orchestration time once, for one pipeline. + # We reset after to verify this is true throughout. + mock_record_orchestration_time.assert_called_once() + mock_record_orchestration_time.reset_mock() # stop_node_services should be called for ExampleGen. self._mock_service_job_manager.stop_node_services.assert_has_calls( [mock.call(mock.ANY, 'ExampleGen')] @@ -1954,6 +2050,9 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline): self._mock_service_job_manager.stop_node_services.assert_has_calls( [mock.call(mock.ANY, 'Transform')] ) + # Check that the orchestration time was recorded again. + mock_record_orchestration_time.assert_called_once() + mock_record_orchestration_time.reset_mock() # Check that the node states are STARTING. [execution] = m.store.get_executions_by_id([pipeline_state.execution_id])