Fix issue where subpipelines may get stuck due to insufficient task schedulers by raising an error when the total number of subpipelines is greater than the maximum allowable task schedulers.

PiperOrigin-RevId: 660011775
kmonte authored and tfx-copybara committed Aug 7, 2024
1 parent 5e90c67 commit aa84350
Showing 4 changed files with 86 additions and 12 deletions.
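
In short, PipelineState.new() now walks the pipeline IR breadth-first, counts every sub_pipeline node across all nesting levels, and refuses to create the run when that count exceeds the environment's limit on concurrent task schedulers. A minimal standalone sketch of the same check (the helper name is illustrative, not part of the patch):

import collections

from tfx.orchestration.experimental.core import env
from tfx.utils import status as status_lib


def check_subpipeline_capacity(pipeline) -> None:
  """Raises if a pipeline_pb2.Pipeline has more subpipelines than allowed task schedulers."""
  num_subpipelines = 0
  to_process = collections.deque([pipeline])
  while to_process:
    p = to_process.popleft()
    for node in p.nodes:
      if node.WhichOneof('node') == 'sub_pipeline':
        num_subpipelines += 1
        # Nested subpipelines count too, so enqueue them for traversal.
        to_process.append(node.sub_pipeline)
  max_task_schedulers = env.get_env().maximum_concurrent_task_schedulers()
  if max_task_schedulers < num_subpipelines:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'{num_subpipelines} subpipelines need at least that many'
            ' concurrent task schedulers, but the environment allows only'
            f' {max_task_schedulers}.'
        ),
    )

The count is deliberately an overestimate: subpipelines that depend on each other never run concurrently and could share schedulers, but failing early is preferable to a run that silently gets stuck.
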
7 changes: 7 additions & 0 deletions tfx/orchestration/experimental/core/env.py
@@ -158,6 +158,10 @@ def get_status_code_from_exception(
    Returns None if the exception is not a known type.
    """

  @abc.abstractmethod
  def maximum_concurrent_task_schedulers(self) -> int:
    """Returns the maximum number of concurrent task schedulers."""


class _DefaultEnv(Env):
"""Default environment."""
@@ -244,6 +248,9 @@ def get_status_code_from_exception(
  ) -> Optional[int]:
    return None

  def maximum_concurrent_task_schedulers(self) -> int:
    return 1


_ENV = _DefaultEnv()
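
The default environment reports a limit of one concurrent task scheduler. Concrete environments must implement the new abstract hook; the TestEnv class in pipeline_state_test.py below overrides it to make the limit configurable. A hypothetical override (class name and constructor are illustrative only) might look like:

from tfx.orchestration.experimental.core import env


class FixedLimitEnv(env._DefaultEnv):
  """Hypothetical environment with a configurable task scheduler limit."""

  def __init__(self, max_task_schedulers: int):
    self._max_task_schedulers = max_task_schedulers

  def maximum_concurrent_task_schedulers(self) -> int:
    return self._max_task_schedulers

The tests below activate such an environment as a context manager, e.g. with TestEnv(None, 20000, 2): ..., so the limit is only in effect for the duration of the test.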

3 changes: 3 additions & 0 deletions tfx/orchestration/experimental/core/env_test.py
@@ -100,6 +100,9 @@ def record_orchestration_time(self, pipeline_run_id: str) -> None:
  def should_orchestrate(self, pipeline: pipeline_pb2.Pipeline) -> bool:
    raise NotImplementedError()

  def maximum_concurrent_task_schedulers(self) -> int:
    raise NotImplementedError()


class EnvTest(test_utils.TfxTest):

29 changes: 29 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -14,6 +14,7 @@
"""Pipeline state management functionality."""

import base64
import collections
import contextlib
import copy
import dataclasses
@@ -515,6 +516,34 @@ def new(
    Raises:
      status_lib.StatusNotOkError: If a pipeline with same UID already exists.
    """
    num_subpipelines = 0
    to_process = collections.deque([pipeline])
    while to_process:
      p = to_process.popleft()
      for node in p.nodes:
        if node.WhichOneof('node') == 'sub_pipeline':
          num_subpipelines += 1
          to_process.append(node.sub_pipeline)
    # If the maximum number of concurrent task schedulers is less than the
    # number of subpipelines, subpipelines may not work.
    # This is because when scheduling a subpipeline, its start node and end
    # node will be scheduled immediately, potentially causing contention where
    # the end node is waiting on some intermediary node to finish, but the
    # intermediary node cannot be scheduled as the end node is running.
    # Note that this number is an overestimate - in reality, if subpipelines
    # depend on each other, we may not need as many task schedulers.
    max_task_schedulers = env.get_env().maximum_concurrent_task_schedulers()
    if max_task_schedulers < num_subpipelines:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.FAILED_PRECONDITION,
          message=(
              f'The maximum number of task schedulers ({max_task_schedulers})'
              f' is less than the number of subpipelines ({num_subpipelines}).'
              ' Please set the maximum number of task schedulers to at least'
              f' {num_subpipelines} in'
              ' OrchestrationOptions.max_running_components.'
          ),
      )
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
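
The comment above captures why the guard is needed: when a subpipeline is scheduled, its start and end nodes get task schedulers immediately, so with too few scheduler slots the end node can sit waiting on an intermediary node that can never be scheduled. A worked example of the count (field names follow the TFX IR proto pipeline_pb2; the snippet only illustrates what the traversal above would see):

from tfx.proto.orchestration import pipeline_pb2

# An outer pipeline containing one subpipeline, which itself contains another
# subpipeline: two sub_pipeline nodes in total across all nesting levels.
outer = pipeline_pb2.Pipeline()
outer.pipeline_info.id = 'pipeline1'
sub1 = outer.nodes.add().sub_pipeline
sub1.pipeline_info.id = 'sub_pipeline1'
sub2 = sub1.nodes.add().sub_pipeline
sub2.pipeline_info.id = 'sub_pipeline2'

# The traversal in PipelineState.new() counts num_subpipelines == 2 here, so
# maximum_concurrent_task_schedulers() must return at least 2 or the new
# FAILED_PRECONDITION error is raised.

The new test below constructs exactly this shape with the test module's _add_sub_pipeline helper and asserts that creation fails when the limit is 1.
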
59 changes: 47 additions & 12 deletions tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -15,6 +15,7 @@

import dataclasses
import os
import sys
import time
from typing import List
from unittest import mock
@@ -36,6 +37,7 @@
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import json_utils
from tfx.utils import status as status_lib

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

@@ -155,9 +157,13 @@ def test_node_state_json(self):

class TestEnv(env._DefaultEnv):

  def __init__(self, base_dir, max_str_len):
  def __init__(self, base_dir, max_str_len, max_task_schedulers):
    self.base_dir = base_dir
    self.max_str_len = max_str_len
    self.max_task_schedulers = max_task_schedulers

  def maximum_concurrent_task_schedulers(self):
    return self.max_task_schedulers

  def get_base_dir(self):
    return self.base_dir
@@ -216,7 +222,7 @@ def test_new_pipeline_state(self):
      self.assertTrue(pstate._active_owned_pipelines_exist)

  def test_new_pipeline_state_with_sub_pipelines(self):
    with self._mlmd_connection as m:
    with TestEnv(None, 20000, 2), self._mlmd_connection as m:
      pstate._active_owned_pipelines_exist = False
      pipeline = _test_pipeline('pipeline1')
      # Add 2 additional layers of sub pipelines. Note that there is no normal
@@ -276,6 +282,33 @@ def test_new_pipeline_state_with_sub_pipelines(self):
          ],
      )

  def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
      self,
  ):
    with TestEnv(None, 20000, 1), self._mlmd_connection as m:
      pstate._active_owned_pipelines_exist = False
      pipeline = _test_pipeline('pipeline1')
      # Add 2 additional layers of sub pipelines. Note that there is no normal
      # pipeline node in the first pipeline layer.
      _add_sub_pipeline(
          pipeline,
          'sub_pipeline1',
          sub_pipeline_nodes=['Trainer'],
          sub_pipeline_run_id='sub_pipeline1_run0',
      )
      _add_sub_pipeline(
          pipeline.nodes[0].sub_pipeline,
          'sub_pipeline2',
          sub_pipeline_nodes=['Trainer'],
          sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
      )
      with self.assertRaisesRegex(
          status_lib.StatusNotOkError,
          'The maximum number of task schedulers',
      ) as e:
        pstate.PipelineState.new(m, pipeline)
      self.assertEqual(e.exception.code, status_lib.Code.FAILED_PRECONDITION)

  def test_load_pipeline_state(self):
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -770,7 +803,9 @@ def test_initiate_node_start_stop(self, mock_time):
    def recorder(event):
      events.append(event)

    with TestEnv(None, 2000), event_observer.init(), self._mlmd_connection as m:
    with TestEnv(
        None, 2000, sys.maxsize
    ), event_observer.init(), self._mlmd_connection as m:
      event_observer.register_observer(recorder)

      pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -900,7 +935,7 @@ def recorder(event):
  @mock.patch.object(pstate, 'time')
  def test_get_node_states_dict(self, mock_time):
    mock_time.time.return_value = time.time()
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1120,7 +1155,7 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
  @mock.patch.object(pstate, 'time')
  def test_pipeline_view_get_node_run_states(self, mock_time):
    mock_time.time.return_value = time.time()
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1205,7 +1240,7 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
  @mock.patch.object(pstate, 'time')
  def test_pipeline_view_get_node_run_state_history(self, mock_time):
    mock_time.time.return_value = time.time()
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1252,7 +1287,7 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
  ):
    """Tests that nodes marked to be skipped have the right node state and previous node state."""
    mock_time.time.return_value = time.time()
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1371,7 +1406,7 @@ def test_load_all_with_list_options(self):
  def test_get_previous_node_run_states_for_skipped_nodes(self, mock_time):
    """Tests that nodes marked to be skipped have the right previous run state."""
    mock_time.time.return_value = time.time()
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1498,7 +1533,7 @@ def test_create_and_load_concurrent_pipeline_runs(self):
      )

  def test_get_pipeline_and_node(self):
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1516,7 +1551,7 @@ def test_get_pipeline_and_node(self):
      )

  def test_get_pipeline_and_node_not_found(self):
    with TestEnv(None, 20000), self._mlmd_connection as m:
    with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
      pipeline = _test_pipeline(
          'pipeline1',
          execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1594,7 +1629,7 @@ def test_save_with_max_str_len(self):
            state=pstate.NodeState.COMPLETE,
        )
    }
    with TestEnv(None, 20):
    with TestEnv(None, 20, sys.maxsize):
      execution = metadata_store_pb2.Execution()
      proxy = pstate._NodeStatesProxy(execution)
      proxy.set(node_states)
@@ -1605,7 +1640,7 @@
          ),
          json_utils.dumps(node_states_without_state_history),
      )
    with TestEnv(None, 2000):
    with TestEnv(None, 2000, sys.maxsize):
      execution = metadata_store_pb2.Execution()
      proxy = pstate._NodeStatesProxy(execution)
      proxy.set(node_states)