Add note and warning about subpipelines maybe not working properly when max_active_task_schedulers is set to one

PiperOrigin-RevId: 660011775
kmonte authored and tfx-copybara committed Aug 6, 2024
1 parent 5e90c67 commit 21f0d53
Showing 5 changed files with 121 additions and 27 deletions.
5 changes: 4 additions & 1 deletion tfx/orchestration/experimental/core/env.py
@@ -14,6 +14,7 @@
"""For environment specific extensions."""

import abc
import sys
from typing import Optional, Sequence

from tfx.orchestration.experimental.core import orchestration_options
@@ -166,7 +167,9 @@ def get_orchestration_options(
self, pipeline: pipeline_pb2.Pipeline
) -> orchestration_options.OrchestrationOptions:
del pipeline
return orchestration_options.OrchestrationOptions()
return orchestration_options.OrchestrationOptions(
max_running_task_schedulers=sys.maxsize
)

def label_and_tag_pipeline_run(
self, mlmd_handle, pipeline_id, pipeline_run_id, labels, tags
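For reference, the default environment above now returns an effectively unbounded limit. A deployment-specific environment can impose a real cap by overriding get_orchestration_options, in the same way the TestEnv helper in the test changes below does. A minimal sketch, assuming _DefaultEnv can be subclassed like this; the class name CappedEnv and its constructor are illustrative, not part of this change:

from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import orchestration_options
from tfx.proto.orchestration import pipeline_pb2


class CappedEnv(env._DefaultEnv):
  """Hypothetical env that caps concurrently running task schedulers."""

  def __init__(self, max_task_schedulers: int):
    self._max_task_schedulers = max_task_schedulers

  def get_orchestration_options(
      self, pipeline: pipeline_pb2.Pipeline
  ) -> orchestration_options.OrchestrationOptions:
    # Keep the default options and only override the scheduler cap.
    defaults = super().get_orchestration_options(pipeline)
    return orchestration_options.OrchestrationOptions(
        fail_fast=defaults.fail_fast,
        deadline_secs=defaults.deadline_secs,
        max_running_task_schedulers=self._max_task_schedulers,
    )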
5 changes: 5 additions & 0 deletions tfx/orchestration/experimental/core/orchestration_options.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Orchestration options."""

import sys
import attr


@@ -27,6 +28,10 @@ class OrchestrationOptions:
failures.
deadline_secs: Only applicable to sync pipelines. If non-zero, a pipeline
run is aborted if the execution duration exceeds deadline_secs seconds.
max_running_task_schedulers: The maximum number of task schedulers that may be
running at a time. Note that this is a GLOBAL limit across all concurrent
runs, subpipeline runs, etc., for a given orchestrator.
"""
fail_fast: bool = False
deadline_secs: int = 0
max_running_task_schedulers: int = sys.maxsize
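As a quick illustration of the new option (a sketch; the cap of 8 is arbitrary), the default stays effectively unbounded and callers opt into a global cap explicitly:

import sys

from tfx.orchestration.experimental.core import orchestration_options

# Default options: effectively no limit on running task schedulers.
default_opts = orchestration_options.OrchestrationOptions()
assert default_opts.max_running_task_schedulers == sys.maxsize

# An orchestrator-wide cap of 8 task schedulers across all concurrent runs
# and subpipeline runs.
capped_opts = orchestration_options.OrchestrationOptions(
    max_running_task_schedulers=8
)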
33 changes: 33 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -14,6 +14,7 @@
"""Pipeline state management functionality."""

import base64
import collections
import contextlib
import copy
import dataclasses
@@ -515,6 +516,38 @@ def new(
Raises:
status_lib.StatusNotOkError: If a pipeline with same UID already exists.
"""
num_subpipelines = 0
to_process = collections.deque([pipeline])
while to_process:
p = to_process.popleft()
for node in p.nodes:
if node.WhichOneof('node') == 'sub_pipeline':
num_subpipelines += 1
to_process.append(node.sub_pipeline)
# If the maximum number of active task schedulers is less than the number of
# subpipelines, subpipelines may not work.
# This is because when a subpipeline is scheduled, its start node and end node
# are scheduled immediately, potentially causing contention where the end node
# is waiting on some intermediary node to finish, but the intermediary node
# cannot be scheduled because the end node is occupying a task scheduler.
# Note that the subpipeline count is an upper bound - in reality, if
# subpipelines depend on each other, the required limit will be lower.
max_task_schedulers = (
env.get_env()
.get_orchestration_options(pipeline)
.max_running_task_schedulers
)
if max_task_schedulers < num_subpipelines:
raise status_lib.StatusNotOkError(
code=status_lib.Code.FAILED_PRECONDITION,
message=(
f'The maximum number of task schedulers ({max_task_schedulers})'
f' is less than the number of subpipelines ({num_subpipelines}).'
' Please set the maximum number of task schedulers to at least'
f' {num_subpipelines} in'
' OrchestrationOptions.max_running_task_schedulers.'
),
)
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
context = context_lib.register_context_if_not_exists(
mlmd_handle,
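The validation added above boils down to a breadth-first count of sub_pipeline nodes at every nesting level, compared against the configured limit. A standalone restatement of the counting step (the helper name count_subpipelines is illustrative only):

import collections

from tfx.proto.orchestration import pipeline_pb2


def count_subpipelines(pipeline: pipeline_pb2.Pipeline) -> int:
  """Counts sub-pipeline nodes at all nesting levels, breadth-first."""
  num_subpipelines = 0
  to_process = collections.deque([pipeline])
  while to_process:
    p = to_process.popleft()
    for node in p.nodes:
      if node.WhichOneof('node') == 'sub_pipeline':
        num_subpipelines += 1
        to_process.append(node.sub_pipeline)
  return num_subpipelines

If this count exceeds OrchestrationOptions.max_running_task_schedulers, PipelineState.new now fails fast with FAILED_PRECONDITION rather than letting the run stall at runtime.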
65 changes: 54 additions & 11 deletions tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -15,6 +15,7 @@

import dataclasses
import os
import sys
import time
from typing import List
from unittest import mock
@@ -26,6 +27,7 @@
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import orchestration_options
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_gen_utils
@@ -36,6 +38,7 @@
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import json_utils
from tfx.utils import status as status_lib

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

@@ -155,9 +158,20 @@ def test_node_state_json(self):

class TestEnv(env._DefaultEnv):

def __init__(self, base_dir, max_str_len):
def __init__(self, base_dir, max_str_len, max_task_schedulers):
self.base_dir = base_dir
self.max_str_len = max_str_len
self.max_task_schedulers = max_task_schedulers

def get_orchestration_options(
self, pipeline: pipeline_pb2.Pipeline
) -> orchestration_options.OrchestrationOptions:
super_options = super().get_orchestration_options(pipeline)
return orchestration_options.OrchestrationOptions(
fail_fast=super_options.fail_fast,
deadline_secs=super_options.deadline_secs,
max_running_task_schedulers=self.max_task_schedulers,
)

def get_base_dir(self):
return self.base_dir
@@ -276,6 +290,33 @@ def test_new_pipeline_state_with_sub_pipelines(self):
],
)

def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
self,
):
with TestEnv(None, 20000, 1), self._mlmd_connection as m:
pstate._active_owned_pipelines_exist = False
pipeline = _test_pipeline('pipeline1')
# Add 2 additional layers of sub pipelines. Note that there is no normal
# pipeline node in the first pipeline layer.
_add_sub_pipeline(
pipeline,
'sub_pipeline1',
sub_pipeline_nodes=['Trainer'],
sub_pipeline_run_id='sub_pipeline1_run0',
)
_add_sub_pipeline(
pipeline.nodes[0].sub_pipeline,
'sub_pipeline2',
sub_pipeline_nodes=['Trainer'],
sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
)
with self.assertRaisesRegex(
status_lib.StatusNotOkError,
'The maximum number of task schedulers',
) as e:
pstate.PipelineState.new(m, pipeline)
self.assertEqual(e.exception.code, status_lib.Code.FAILED_PRECONDITION)

def test_load_pipeline_state(self):
with self._mlmd_connection as m:
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -770,7 +811,9 @@ def test_initiate_node_start_stop(self, mock_time):
def recorder(event):
events.append(event)

with TestEnv(None, 2000), event_observer.init(), self._mlmd_connection as m:
with TestEnv(
None, 2000, sys.maxsize
), event_observer.init(), self._mlmd_connection as m:
event_observer.register_observer(recorder)

pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -900,7 +943,7 @@ def recorder(event):
@mock.patch.object(pstate, 'time')
def test_get_node_states_dict(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1120,7 +1163,7 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
@mock.patch.object(pstate, 'time')
def test_pipeline_view_get_node_run_states(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1205,7 +1248,7 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
@mock.patch.object(pstate, 'time')
def test_pipeline_view_get_node_run_state_history(self, mock_time):
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1252,7 +1295,7 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
):
"""Tests that nodes marked to be skipped have the right node state and previous node state."""
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1371,7 +1414,7 @@ def test_load_all_with_list_options(self):
def test_get_previous_node_run_states_for_skipped_nodes(self, mock_time):
"""Tests that nodes marked to be skipped have the right previous run state."""
mock_time.time.return_value = time.time()
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1498,7 +1541,7 @@ def test_create_and_load_concurrent_pipeline_runs(self):
)

def test_get_pipeline_and_node(self):
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1516,7 +1559,7 @@ def test_get_pipeline_and_node(self):
)

def test_get_pipeline_and_node_not_found(self):
with TestEnv(None, 20000), self._mlmd_connection as m:
with TestEnv(None, 20000, sys.maxsize), self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1',
execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1594,7 +1637,7 @@ def test_save_with_max_str_len(self):
state=pstate.NodeState.COMPLETE,
)
}
with TestEnv(None, 20):
with TestEnv(None, 20, sys.maxsize):
execution = metadata_store_pb2.Execution()
proxy = pstate._NodeStatesProxy(execution)
proxy.set(node_states)
@@ -1605,7 +1648,7 @@
),
json_utils.dumps(node_states_without_state_history),
)
with TestEnv(None, 2000):
with TestEnv(None, 2000, sys.maxsize):
execution = metadata_store_pb2.Execution()
proxy = pstate._NodeStatesProxy(execution)
proxy.set(node_states)
40 changes: 25 additions & 15 deletions tfx/orchestration/experimental/core/task_manager.py
@@ -137,12 +137,14 @@ class TaskManager:
TaskManager instance can be used as a context manager:
"""

def __init__(self,
mlmd_handle: metadata.Metadata,
task_queue: tq.TaskQueue,
max_active_task_schedulers: int,
max_dequeue_wait_secs: float = _MAX_DEQUEUE_WAIT_SECS,
process_all_queued_tasks_before_exit: bool = False):
def __init__(
self,
mlmd_handle: metadata.Metadata,
task_queue: tq.TaskQueue,
max_active_task_schedulers: int,
max_dequeue_wait_secs: float = _MAX_DEQUEUE_WAIT_SECS,
process_all_queued_tasks_before_exit: bool = False,
):
"""Constructs `TaskManager`.
Args:
@@ -160,7 +162,8 @@ def __init__(self,
self._task_queue = task_queue
self._max_dequeue_wait_secs = max_dequeue_wait_secs
self._process_all_queued_tasks_before_exit = (
process_all_queued_tasks_before_exit)
process_all_queued_tasks_before_exit
)

self._tm_lock = threading.Lock()
self._stop_event = threading.Event()
@@ -216,8 +219,10 @@ def exception(self) -> Optional[BaseException]:
if self._main_future is None:
raise RuntimeError('Task manager context not entered.')
if not self._main_future.done():
raise RuntimeError('Task manager main thread not done; call should be '
'conditioned on `done` returning `True`.')
raise RuntimeError(
'Task manager main thread not done; call should be '
'conditioned on `done` returning `True`.'
)
return self._main_future.exception()

def _main(self) -> None:
@@ -271,7 +276,8 @@ def _handle_exec_node_task(self, task: task_lib.ExecNodeTask) -> None:
if node_uid in self._scheduler_by_node_uid:
raise RuntimeError(
'Cannot create multiple task schedulers for the same task; '
'task_id: {}'.format(task.task_id))
'task_id: {}'.format(task.task_id)
)
scheduler = _SchedulerWrapper(
typing.cast(
ts.TaskScheduler[task_lib.ExecNodeTask],
@@ -294,13 +300,16 @@ def _handle_cancel_node_task(self, task: task_lib.CancelNodeTask) -> None:
if scheduler is None:
logging.info(
'No task scheduled for node uid: %s. The task might have already '
'completed before it could be cancelled.', task.node_uid)
'completed before it could be cancelled.',
task.node_uid,
)
else:
scheduler.cancel(cancel_task=task)
self._task_queue.task_done(task)

def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
task: task_lib.ExecNodeTask) -> None:
def _process_exec_node_task(
self, scheduler: _SchedulerWrapper, task: task_lib.ExecNodeTask
) -> None:
"""Processes an `ExecNodeTask` using the given task scheduler."""
# This is a blocking call to the scheduler which can take a long time to
# complete for some types of task schedulers. The scheduler is expected to
Expand All @@ -318,7 +327,7 @@ def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
code=status_lib.Code.UNKNOWN,
message=''.join(
traceback.format_exception(*sys.exc_info(), limit=1),
)
),
)
result = ts.TaskSchedulerResult(status=status)
logging.info(
@@ -414,5 +423,6 @@ def _cleanup(self, final: bool = False) -> None:
'Exception %d (out of %d):',
i,
len(exceptions),
exc_info=(type(e), e, e.__traceback__))
exc_info=(type(e), e, e.__traceback__),
)
raise TasksProcessingError(exceptions)
