Skip to content

Commit

Permalink
Orchestrator shouldn't crash when MLMD call fails
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641385227
  • Loading branch information
tfx-copybara committed Jul 2, 2024
1 parent 4e71a35 commit 4c47271
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 3 deletions.
19 changes: 19 additions & 0 deletions tfx/orchestration/experimental/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,20 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool:
Whether the env should orchestrate the pipeline.
"""

@abc.abstractmethod
def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Maps an exception to its status code.

  Args:
    exception: The exception to inspect; may be None.

  Returns:
    The status code of the exception, or None if the exception is not a
    known type.
  """


class _DefaultEnv(Env):
"""Default environment."""
Expand Down Expand Up @@ -211,6 +225,11 @@ def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool:
# By default, all pipeline runs should be orchestrated.
return True

def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Reports no status code, whatever exception is given.

  Args:
    exception: Ignored by this implementation.

  Returns:
    Always None.
  """
  del exception  # Unused: this implementation recognizes no exception type.
  return None


_ENV = _DefaultEnv()

Expand Down
5 changes: 5 additions & 0 deletions tfx/orchestration/experimental/core/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def prepare_orchestrator_for_pipeline_run(
):
raise NotImplementedError()

def get_status_code_from_exception(
    self, exception: Optional[BaseException]
) -> Optional[int]:
  """Not implemented here; always raises NotImplementedError.

  Args:
    exception: Ignored.

  Raises:
    NotImplementedError: Unconditionally.
  """
  raise NotImplementedError()

def create_sync_or_upsert_async_pipeline_run(
self,
owner: str,
Expand Down
19 changes: 16 additions & 3 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,22 @@ def orchestrate(
if filter_fn is None:
filter_fn = lambda _: True

all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
# Try to load active pipelines. If there is a recoverable error, return True
# and then retry in the next orchestration iteration.
try:
all_pipeline_states = pstate.PipelineState.load_all_active_and_owned(
mlmd_connection_manager.primary_mlmd_handle
)
except Exception as e: # pylint: disable=broad-except
code = env.get_env().get_status_code_from_exception(e)
if code in status_lib.BATCH_RETRIABLE_ERROR_CODES:
logging.exception(
'Failed to load active pipeline states. Will retry in next'
' orchestration iteration.',
)
return True
raise e

pipeline_states = [s for s in all_pipeline_states if filter_fn(s)]
if not pipeline_states:
logging.info('No active pipelines to run.')
Expand Down
78 changes: 78 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tfx.types import standard_artifacts
from tfx.utils import status as status_lib

from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2


Expand Down Expand Up @@ -3589,6 +3590,83 @@ def test_orchestrate_pipelines_with_specified_pipeline_uid(
)
self.assertTrue(task_queue.is_empty())

@parameterized.parameters(
    (mlmd_errors.DeadlineExceededError('DeadlineExceededError'), 4),
    (mlmd_errors.InternalError('InternalError'), 13),
    (mlmd_errors.UnavailableError('UnavailableError'), 14),
    (mlmd_errors.ResourceExhaustedError('ResourceExhaustedError'), 8),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.DEADLINE_EXCEEDED,
            message='DeadlineExceededError',
        ),
        4,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL, message='InternalError'
        ),
        13,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.UNAVAILABLE, message='UnavailableError'
        ),
        14,
    ),
    (
        status_lib.StatusNotOkError(
            code=status_lib.Code.RESOURCE_EXHAUSTED,
            message='ResourceExhaustedError',
        ),
        8,
    ),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_recoverable_error_from_MLMD(
    self,
    error,
    error_code,
    mock_load_all_active_and_owned,
):
  """Checks that orchestrate() returns True on a recoverable load failure.

  Args:
    error: Exception raised by the patched
      PipelineState.load_all_active_and_owned.
    error_code: Integer status code paired with `error`; handed to the test
      environment helper.
    mock_load_all_active_and_owned: Mock injected by mock.patch.object.
  """
  # Every attempt to load the active pipeline states raises `error`.
  mock_load_all_active_and_owned.side_effect = error

  # NOTE(review): the helper presumably installs an env whose
  # get_status_code_from_exception() yields `error_code` -- confirm in
  # test_utils.
  with test_utils.get_status_code_from_exception_environment(error_code):
    with self._mlmd_cm as mlmd_connection_manager:
      result = pipeline_ops.orchestrate(
          mlmd_connection_manager,
          tq.TaskQueue(),
          service_jobs.DummyServiceJobManager(),
      )
      self.assertEqual(result, True)

@parameterized.parameters(
    mlmd_errors.InvalidArgumentError('InvalidArgumentError'),
    mlmd_errors.FailedPreconditionError('FailedPreconditionError'),
    status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT, message='InvalidArgumentError'
    ),
    status_lib.StatusNotOkError(
        code=status_lib.Code.UNKNOWN,
        message='UNKNOWN',
    ),
)
@mock.patch.object(pstate.PipelineState, 'load_all_active_and_owned')
def test_orchestrate_pipelines_with_not_recoverable_error_from_MLMD(
    self, error, mock_load_all_active_and_owned
):
  """Checks that orchestrate() raises when loading pipeline states fails.

  Args:
    error: Exception raised by the patched
      PipelineState.load_all_active_and_owned.
    mock_load_all_active_and_owned: Mock injected by mock.patch.object.
  """
  # Every attempt to load the active pipeline states raises `error`.
  mock_load_all_active_and_owned.side_effect = error

  with self._mlmd_cm as mlmd_connection_manager:
    queue = tq.TaskQueue()
    # NOTE(review): assertRaises(Exception) accepts any exception type;
    # consider narrowing to type(error) once confirmed that orchestrate()
    # propagates the original exception unchanged.
    with self.assertRaises(Exception):
      pipeline_ops.orchestrate(
          mlmd_connection_manager,
          queue,
          service_jobs.DummyServiceJobManager(),
      )


if __name__ == '__main__':
tf.test.main()
20 changes: 20 additions & 0 deletions tfx/utils/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ class Code(enum.IntEnum):
UNAUTHENTICATED = 16


# Integer status codes that are retriable for USER_FACING traffic.
# See go/stubs-retries.
USER_FACING_RETRIABLE_STATUS_CODES = frozenset({Code.UNAVAILABLE.value})

# Integer status codes in the batch retriable set (presumably the batch-traffic
# counterpart of the set above, judging by the name).
BATCH_RETRIABLE_ERROR_CODES = frozenset({
    Code.DEADLINE_EXCEEDED.value,
    Code.INTERNAL.value,
    Code.UNAVAILABLE.value,
    Code.RESOURCE_EXHAUSTED.value,
})


@attr.s(auto_attribs=True, frozen=True)
class Status:
"""Class to record status of operations.
Expand Down

0 comments on commit 4c47271

Please sign in to comment.