
Commit

no-op
PiperOrigin-RevId: 645442735
tfx-copybara committed Jun 21, 2024
1 parent f94b329 commit 34fd64b
Showing 2 changed files with 121 additions and 2 deletions.
20 changes: 20 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops.py
@@ -33,6 +33,7 @@
from tfx.dsl.io import filesystem
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration import subpipeline_utils
from tfx.orchestration.experimental.core import async_pipeline_task_gen
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import env
@@ -1252,6 +1253,20 @@ def filter_by_pipeline_uid(
return lambda p: p.pipeline_uid == pipeline_uid


def _record_orchestration_time(pipeline_state: pstate.PipelineState) -> None:
"""Records an orchestration time for the pipeline run."""
# We only care about orchestration time for root pipelines, so skip any
# subpipelines.
if subpipeline_utils.is_subpipeline(pipeline_state.pipeline):
return
pipeline_run_id = pipeline_state.pipeline_run_id
# The backend expects an empty string as the pipeline run id for ASYNC
# pipeline runs.
if pipeline_run_id is None:
pipeline_run_id = ''
env.get_env().record_orchestration_time(pipeline_run_id)


@_pipeline_op()
def orchestrate(
mlmd_connection_manager: mlmd_cm.MLMDConnectionManager,
@@ -1322,6 +1337,7 @@ def orchestrate(
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If orchestrating a stop-initiated pipeline raises an exception, we log
# the exception but do not re-raise since we do not want to crash the
@@ -1345,6 +1361,7 @@
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating update-initiated pipeline %s',
@@ -1364,6 +1381,7 @@
),
)
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If stop initiation also raised an exception, we log the exception but
# do not re-raise since we do not want to crash the orchestrator. If
@@ -1387,6 +1405,7 @@
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating active pipeline %s',
@@ -1404,6 +1423,7 @@
message=f'Error orchestrating active pipeline: {str(e)}',
)
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If stop initiation also raised an exception, we log the exception but
# do not re-raise since we do not want to crash the orchestrator. If
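
The new helper funnels every completed orchestration pass through env.get_env().record_orchestration_time(pipeline_run_id), skipping subpipelines and passing '' for ASYNC pipelines, which have no run id. As a rough illustration of the bookkeeping an environment-side hook could perform with that value, here is a minimal, self-contained sketch; the OrchestrationTimeRecorder class and its method names are assumptions for illustration only, not part of this commit or of the real Env interface.

import collections
import time


class OrchestrationTimeRecorder:
  """Illustrative stand-in for an environment-side orchestration-time hook."""

  def __init__(self) -> None:
    # Maps pipeline_run_id ('' for ASYNC pipelines) to the wall-clock time of
    # the most recent completed orchestration pass.
    self._last_orchestrated = collections.defaultdict(float)

  def record_orchestration_time(self, pipeline_run_id: str) -> None:
    # Called once per root pipeline per orchestration loop.
    self._last_orchestrated[pipeline_run_id] = time.time()

  def seconds_since_last_orchestration(self, pipeline_run_id: str) -> float:
    # Useful for alerting when a run has not been orchestrated recently.
    return time.time() - self._last_orchestrated[pipeline_run_id]
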
103 changes: 101 additions & 2 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -1434,10 +1434,73 @@ def test_stop_node_wait_for_inactivation_timeout(self):
(pstate.NodeState.STOPPING, pstate.NodeState.STOPPED),
)

@parameterized.named_parameters(
dict(
testcase_name='async',
pipeline=_test_pipeline('pipeline1'),
expected_run_id='',
),
dict(
testcase_name='sync',
pipeline=_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC),
expected_run_id='run0',
),
)
def test_record_orchestration_time(self, pipeline, expected_run_id):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline_ops.initiate_pipeline_start(m, pipeline)
environment = env.get_env()
with mock.patch.object(
environment,
'record_orchestration_time',
wraps=environment.record_orchestration_time,
) as mock_env_record_orchestration_time:
task_queue = tq.TaskQueue()
pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
self._mock_service_job_manager,
)
mock_env_record_orchestration_time.assert_called_with(expected_run_id)

def test_record_orchestration_time_subpipeline(self):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0',
},
)
pipeline_ops.initiate_pipeline_start(m, pipeline)
environment = env.get_env()
with mock.patch.object(
environment,
'record_orchestration_time',
wraps=environment.record_orchestration_time,
) as mock_env_record_orchestration_time:
task_queue = tq.TaskQueue()
pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
self._mock_service_job_manager,
)
mock_env_record_orchestration_time.assert_called_with('run0')

@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_active_pipelines(
self, mock_async_task_gen, mock_sync_task_gen
self,
mock_record_orchestration_time,
mock_async_task_gen,
mock_sync_task_gen,
):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
@@ -1509,6 +1572,15 @@ def test_orchestrate_active_pipelines(
service_jobs.DummyServiceJobManager(),
)

# Check that the orchestration time was recorded four times, once for each
# of the four pipelines.
mock_record_orchestration_time.assert_has_calls([
mock.call(mock.ANY),
mock.call(mock.ANY),
mock.call(mock.ANY),
mock.call(mock.ANY),
])

self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

@@ -1550,9 +1622,15 @@ def test_orchestrate_active_pipelines(
@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
)
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_stop_initiated_pipelines(
self,
pipeline,
mock_record_orchestration_time,
mock_gen_task_from_active,
mock_async_task_gen,
mock_sync_task_gen,
@@ -1617,6 +1695,10 @@ def recorder(event):
self._mock_service_job_manager,
)
)
# We should have recorded the orchestration time once, for one pipeline.
# We reset the mock afterwards to verify this holds throughout.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()

# PipelineFinished event should not trigger since not all the nodes are
# stopped.
Expand Down Expand Up @@ -1683,6 +1765,8 @@ def recorder(event):
self._mock_service_job_manager,
)
)
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()
self.assertTrue(task_queue.is_empty())
[execution] = m.store.get_executions_by_id([pipeline_execution_id])
self.assertEqual(
@@ -1721,6 +1805,7 @@ def recorder(event):
self._mock_service_job_manager,
)
)
mock_record_orchestration_time.assert_not_called()

@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
@@ -1890,7 +1975,14 @@ def recorder(event):
_test_pipeline('pipeline1'),
_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC),
)
def test_orchestrate_update_initiated_pipelines(self, pipeline):
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_update_initiated_pipelines(
self, pipeline, mock_record_orchestration_time
):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
@@ -1924,6 +2016,10 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline):
pipeline_ops.orchestrate(
mlmd_connection_manager, task_queue, self._mock_service_job_manager
)
# We should have recorded the orchestration time once, for one pipeline.
# We reset the mock afterwards to verify this holds throughout.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()
# stop_node_services should be called for ExampleGen.
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'ExampleGen')]
@@ -1954,6 +2050,9 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline):
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'Transform')]
)
# Check that the orchestration time was recorded again.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()

# Check that the node states are STARTING.
[execution] = m.store.get_executions_by_id([pipeline_state.execution_id])
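
These tests spy on _record_orchestration_time and on the environment's record_orchestration_time with mock.patch.object(..., wraps=...), so the real code still runs while the mock records calls for assertions and can be cleared with reset_mock() between phases. A condensed, standalone sketch of that pattern follows; the Recorder class and orchestrate_once function are toy stand-ins, not TFX APIs.

from unittest import mock


class Recorder:
  """Toy stand-in for the environment hook spied on in the tests above."""

  def record_orchestration_time(self, run_id: str) -> None:
    pass  # Real bookkeeping would happen here.


def orchestrate_once(recorder: Recorder, run_id: str) -> None:
  # Stand-in for pipeline_ops.orchestrate(), which calls the hook once per
  # root pipeline at the end of each orchestration pass.
  recorder.record_orchestration_time(run_id)


recorder = Recorder()
with mock.patch.object(
    recorder,
    'record_orchestration_time',
    wraps=recorder.record_orchestration_time,
) as spy:
  orchestrate_once(recorder, 'run0')
  spy.assert_called_once_with('run0')  # The real method still executed.
  spy.reset_mock()                     # Clear call history between phases.
  orchestrate_once(recorder, '')       # '' mirrors the ASYNC-run convention.
  spy.assert_called_once_with('')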
