From 88d7f1841cb1bcf5f78cc1f6db87280c3d23d9a4 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Fri, 7 Jun 2024 16:08:16 -0700 Subject: [PATCH] Orchestrator shouldn't crash when MLMD call fails PiperOrigin-RevId: 641385227 --- .../experimental/core/pipeline_ops.py | 12 +++- .../experimental/core/pipeline_ops_test.py | 61 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 76188665f9e..0e5ed9fc564 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -1298,9 +1298,15 @@ def orchestrate( if filter_fn is None: filter_fn = lambda _: True - all_pipeline_states = pstate.PipelineState.load_all_active_and_owned( - mlmd_connection_manager.primary_mlmd_handle - ) + # Try to load active pipelines. If there is a recoverable error, return False + # and then retry in the next orchestration iteration. + try: + all_pipeline_states = pstate.PipelineState.load_all_active_and_owned( + mlmd_connection_manager.primary_mlmd_handle + ) + except Exception as e: # pylint: disable=broad-except + raise e + pipeline_states = [s for s in all_pipeline_states if filter_fn(s)] if not pipeline_states: logging.info('No active pipelines to run.') diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index 56bb1151871..da2f1950573 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -54,6 +54,7 @@ from tfx.types import standard_artifacts from tfx.utils import status as status_lib +from ml_metadata import errors as mlmd_errors from ml_metadata.proto import metadata_store_pb2 @@ -3589,6 +3590,66 @@ def test_orchestrate_pipelines_with_specified_pipeline_uid( ) self.assertTrue(task_queue.is_empty()) + @parameterized.parameters( + mlmd_errors.DeadlineExceededError('DeadlineExceededError'), + mlmd_errors.InternalError('InternalError'), + mlmd_errors.UnavailableError('UnavailableError'), + mlmd_errors.ResourceExhaustedError('ResourceExhaustedError'), + status_lib.StatusNotOkError( + code=status_lib.Code.DEADLINE_EXCEEDED, + message='DeadlineExceededError', + ), + status_lib.StatusNotOkError( + code=status_lib.Code.INTERNAL, message='InternalError' + ), + status_lib.StatusNotOkError( + code=status_lib.Code.UNAVAILABLE, message='UnavailableError' + ), + status_lib.StatusNotOkError( + code=status_lib.Code.RESOURCE_EXHAUSTED, + message='ResourceExhaustedError', + ), + ) + @mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned') + def test_orchestrate_pipelines_with_recoverable_error_from_MLMD( + self, error, mock_load_all_active_and_owned + ): + mock_load_all_active_and_owned.side_effect = error + + with self._mlmd_cm as mlmd_connection_manager: + task_queue = tq.TaskQueue() + orchestrate_result = pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + service_jobs.DummyServiceJobManager(), + ) + self.assertEqual(orchestrate_result, False) + + @parameterized.parameters( + mlmd_errors.InvalidArgumentError('InvalidArgumentError'), + mlmd_errors.FailedPreconditionError('FailedPreconditionError'), + status_lib.StatusNotOkError( + code=status_lib.Code.INVALID_ARGUMENT, message='InvalidArgumentError' + ), + status_lib.StatusNotOkError( + code=status_lib.Code.UNKNOWN, + message='UNKNOWN', + ), + ) + @mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned') + def test_orchestrate_pipelines_with_not_recoverable_error_from_MLMD( + self, error, mock_load_all_active_and_owned + ): + mock_load_all_active_and_owned.side_effect = error + + with self._mlmd_cm as mlmd_connection_manager: + task_queue = tq.TaskQueue() + with self.assertRaises(Exception): + pipeline_ops.orchestrate( + mlmd_connection_manager, + task_queue, + service_jobs.DummyServiceJobManager(), + ) if __name__ == '__main__': tf.test.main()