diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 945553ae58..9be76a4792 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -14,6 +14,7 @@
 """Pipeline state management functionality."""
 
 import base64
+import collections
 import contextlib
 import copy
 import dataclasses
@@ -515,6 +516,34 @@ def new(
     Raises:
       status_lib.StatusNotOkError: If a pipeline with same UID already exists.
     """
+    num_subpipelines = 0
+    to_process = collections.deque([pipeline])
+    while to_process:
+      p = to_process.popleft()
+      for node in p.nodes:
+        if node.WhichOneof('node') == 'sub_pipeline':
+          num_subpipelines += 1
+          to_process.append(node.sub_pipeline)
+    # If the maximum number of active task schedulers is less than the number
+    # of subpipelines, subpipelines may not work.
+    # This is because when scheduling a subpipeline, its start node and end
+    # node are scheduled immediately, potentially causing contention where the
+    # end node is waiting on some intermediary node to finish, but the
+    # intermediary node cannot be scheduled because the end node is running.
+    # Note that this number is an overestimate: in reality, if subpipelines
+    # depend on each other, we may not need this many task schedulers.
+    max_task_schedulers = env.get_env().maximum_active_task_schedulers()
+    if max_task_schedulers < num_subpipelines:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.FAILED_PRECONDITION,
+          message=(
+              f'The maximum number of task schedulers ({max_task_schedulers})'
+              f' is less than the number of subpipelines ({num_subpipelines}).'
+              ' Please set the maximum number of task schedulers to at least'
+              f' {num_subpipelines} in'
+              ' OrchestrationOptions.max_running_components.'
+          ),
+      )
     pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
     context = context_lib.register_context_if_not_exists(
         mlmd_handle,
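Aside: the counting walk above, in isolation. A minimal runnable sketch that uses plain dicts in place of pipeline_pb2 protos (the count_subpipelines helper and the dict shape are illustrative, not TFX APIs); the real code keys off node.WhichOneof('node') == 'sub_pipeline' rather than a dict key.

```python
import collections

def count_subpipelines(pipeline: dict) -> int:
  """Breadth-first count of sub_pipeline nodes across all nesting levels."""
  num_subpipelines = 0
  to_process = collections.deque([pipeline])
  while to_process:
    p = to_process.popleft()
    for node in p['nodes']:
      if 'sub_pipeline' in node:  # Stand-in for WhichOneof('node') == 'sub_pipeline'.
        num_subpipelines += 1
        to_process.append(node['sub_pipeline'])
  return num_subpipelines

# Two nested layers, mirroring the tests below:
# pipeline1 -> sub_pipeline1 -> sub_pipeline2.
pipeline = {'nodes': [
    {'sub_pipeline': {'nodes': [
        {'sub_pipeline': {'nodes': [{'pipeline_node': {}}]}},
    ]}},
]}
assert count_subpipelines(pipeline) == 2  # Needs max_task_schedulers >= 2.
```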
diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py
index 857573c7f5..b7e02cb0e4 100644
--- a/tfx/orchestration/experimental/core/pipeline_state_test.py
+++ b/tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -15,8 +15,9 @@
 
 import dataclasses
 import os
+import sys
 import time
-from typing import List
+from typing import List, Optional
 from unittest import mock
 
 from absl.testing import parameterized
@@ -36,6 +37,7 @@
 from tfx.proto.orchestration import run_state_pb2
 from tfx.utils import json_utils
 from tfx.utils import status as status_lib
+
 import ml_metadata as mlmd
 from ml_metadata.proto import metadata_store_pb2
 
@@ -155,9 +157,19 @@ def test_node_state_json(self):
 
 class TestEnv(env._DefaultEnv):
 
-  def __init__(self, base_dir, max_str_len):
+  def __init__(
+      self,
+      *,
+      base_dir: Optional[str],
+      max_str_len: int,
+      max_task_schedulers: int
+  ):
     self.base_dir = base_dir
     self.max_str_len = max_str_len
+    self.max_task_schedulers = max_task_schedulers
+
+  def maximum_active_task_schedulers(self) -> int:
+    return self.max_task_schedulers
 
   def get_base_dir(self):
     return self.base_dir
@@ -216,7 +228,9 @@ def test_new_pipeline_state(self):
       self.assertTrue(pstate._active_owned_pipelines_exist)
 
   def test_new_pipeline_state_with_sub_pipelines(self):
-    with self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=2
+    ), self._mlmd_connection as m:
       pstate._active_owned_pipelines_exist = False
       pipeline = _test_pipeline('pipeline1')
       # Add 2 additional layers of sub pipelines. Note that there is no normal
@@ -276,6 +290,35 @@
         ],
     )
 
+  def test_new_pipeline_state_with_sub_pipelines_fails_when_not_enough_task_schedulers(
+      self,
+  ):
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=1
+    ), self._mlmd_connection as m:
+      pstate._active_owned_pipelines_exist = False
+      pipeline = _test_pipeline('pipeline1')
+      # Add 2 additional layers of sub pipelines. Note that there is no normal
+      # pipeline node in the first pipeline layer.
+      _add_sub_pipeline(
+          pipeline,
+          'sub_pipeline1',
+          sub_pipeline_nodes=['Trainer'],
+          sub_pipeline_run_id='sub_pipeline1_run0',
+      )
+      _add_sub_pipeline(
+          pipeline.nodes[0].sub_pipeline,
+          'sub_pipeline2',
+          sub_pipeline_nodes=['Trainer'],
+          sub_pipeline_run_id='sub_pipeline1_sub_pipeline2_run0',
+      )
+      with self.assertRaisesRegex(
+          status_lib.StatusNotOkError,
+          'The maximum number of task schedulers',
+      ) as e:
+        pstate.PipelineState.new(m, pipeline)
+      self.assertEqual(e.exception.code, status_lib.Code.FAILED_PRECONDITION)
+
   def test_load_pipeline_state(self):
     with self._mlmd_connection as m:
       pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
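For context on how the TestEnv override takes effect: env._DefaultEnv is used as a context manager, so while the with block is active, env.get_env() inside PipelineState.new resolves to the test instance and maximum_active_task_schedulers() returns the test's value. A simplified sketch of that pattern under stated assumptions (the names here are illustrative, not TFX's actual env module):

```python
import sys

_current_env = None  # Process-global environment, swapped by the override.

class Env:
  """Default environment: effectively no task-scheduler limit."""

  def maximum_active_task_schedulers(self) -> int:
    return sys.maxsize

class OverrideEnv(Env):
  """Context manager that installs itself as the current env for a block."""

  def __init__(self, max_task_schedulers: int):
    self.max_task_schedulers = max_task_schedulers

  def maximum_active_task_schedulers(self) -> int:
    return self.max_task_schedulers

  def __enter__(self):
    global _current_env
    self._previous, _current_env = _current_env, self
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    global _current_env
    _current_env = self._previous

def get_env() -> Env:
  return _current_env if _current_env is not None else Env()

with OverrideEnv(max_task_schedulers=1):
  assert get_env().maximum_active_task_schedulers() == 1  # Code under test sees the override.
assert get_env().maximum_active_task_schedulers() == sys.maxsize  # Restored on exit.
```

The remaining hunks below mechanically update existing TestEnv call sites to the new keyword-only constructor, passing sys.maxsize so the new precondition never trips in unrelated tests.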
@@ -770,7 +813,9 @@ def test_initiate_node_start_stop(self, mock_time):
     def recorder(event):
       events.append(event)
 
-    with TestEnv(None, 2000), event_observer.init(), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=2000, max_task_schedulers=sys.maxsize
+    ), event_observer.init(), self._mlmd_connection as m:
       event_observer.register_observer(recorder)
 
       pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
@@ -900,7 +945,9 @@ def recorder(event):
   @mock.patch.object(pstate, 'time')
   def test_get_node_states_dict(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1120,7 +1167,9 @@ def test_pipeline_view_get_pipeline_run_state(self, mock_time):
   @mock.patch.object(pstate, 'time')
   def test_pipeline_view_get_node_run_states(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1205,7 +1254,9 @@ def test_pipeline_view_get_node_run_states(self, mock_time):
   @mock.patch.object(pstate, 'time')
   def test_pipeline_view_get_node_run_state_history(self, mock_time):
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1252,7 +1303,9 @@ def test_node_state_for_skipped_nodes_in_partial_pipeline_run(
   ):
     """Tests that nodes marked to be skipped have the right node state and previous node state."""
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1371,7 +1424,9 @@ def test_load_all_with_list_options(self):
   def test_get_previous_node_run_states_for_skipped_nodes(self, mock_time):
     """Tests that nodes marked to be skipped have the right previous run state."""
     mock_time.time.return_value = time.time()
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1498,7 +1553,9 @@ def test_create_and_load_concurrent_pipeline_runs(self):
     )
 
   def test_get_pipeline_and_node(self):
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1516,7 +1573,9 @@ def test_get_pipeline_and_node(self):
     )
 
   def test_get_pipeline_and_node_not_found(self):
-    with TestEnv(None, 20000), self._mlmd_connection as m:
+    with TestEnv(
+        base_dir=None, max_str_len=20000, max_task_schedulers=sys.maxsize
+    ), self._mlmd_connection as m:
       pipeline = _test_pipeline(
           'pipeline1',
           execution_mode=pipeline_pb2.Pipeline.SYNC,
@@ -1594,7 +1653,9 @@ def test_save_with_max_str_len(self):
             state=pstate.NodeState.COMPLETE,
         )
     }
-    with TestEnv(None, 20):
+    with TestEnv(
+        base_dir=None, max_str_len=20, max_task_schedulers=sys.maxsize
+    ):
       execution = metadata_store_pb2.Execution()
       proxy = pstate._NodeStatesProxy(execution)
       proxy.set(node_states)
@@ -1605,7 +1666,9 @@
         ),
         json_utils.dumps(node_states_without_state_history),
     )
-    with TestEnv(None, 2000):
+    with TestEnv(
+        base_dir=None, max_str_len=2000, max_task_schedulers=sys.maxsize
+    ):
       execution = metadata_store_pb2.Execution()
       proxy = pstate._NodeStatesProxy(execution)
       proxy.set(node_states)
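One behavioral detail worth noting: the precondition uses a strict comparison, so a scheduler limit exactly equal to the subpipeline count is accepted. A small sketch of the boundary (check_task_schedulers is a hypothetical stand-in; the real guard lives in PipelineState.new and raises status_lib.StatusNotOkError):

```python
def check_task_schedulers(max_task_schedulers: int, num_subpipelines: int) -> None:
  # Mirrors the guard in PipelineState.new: fail only when the limit is
  # strictly below the (over-)estimated subpipeline count.
  if max_task_schedulers < num_subpipelines:
    raise RuntimeError(
        f'The maximum number of task schedulers ({max_task_schedulers})'
        f' is less than the number of subpipelines ({num_subpipelines}).'
    )

check_task_schedulers(2, 2)  # Passes: the limit may equal the count.
try:
  check_task_schedulers(1, 2)  # Raises, like the failing test above.
except RuntimeError:
  pass
```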