From 78441beee22d8adbf40f29e3a32349a06cac92e7 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Fri, 14 Jun 2024 12:23:51 -0700 Subject: [PATCH] Refactor the _get_pipeline_and_node function. Moved the function from pipeline_ops_extensions.py to pipeline_state.py and updated its function calls in pipeline_ops_extensions.py accordingly. Also added tests for the _get_pipeline_and_node function in pipeline_state_test.py. PiperOrigin-RevId: 643421872 --- .../experimental/core/pipeline_state.py | 45 +++++++++++++++++++ .../experimental/core/pipeline_state_test.py | 36 ++++++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 32139c5e62..7a236d6c2a 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -1669,3 +1669,48 @@ def _get_sub_pipeline_ids_from_pipeline_info( sub_pipeline_ids = pipeline_info.parent_ids[1:] sub_pipeline_ids.append(pipeline_info.id) return sub_pipeline_ids + + +def get_pipeline_and_node( + mlmd_handle: metadata.Metadata, + node_uid: task_lib.NodeUid, + pipeline_run_id: str, +) -> tuple[pipeline_pb2.Pipeline, node_proto_view.PipelineNodeProtoView]: + """Gets the pipeline and node for the node_uid. + + This function is experimental, and should only be used when publishing + external and intermediate artifacts. + + Args: + mlmd_handle: A handle to the MLMD db. + node_uid: Node uid of the node to get. + pipeline_run_id: Run id of the pipeline for the synchronous pipeline. + + Returns: + A tuple with the pipeline and node proto view for the node_uid. 
+ """ + with PipelineState.load(mlmd_handle, node_uid.pipeline_uid) as pipeline_state: + if ( + pipeline_run_id or pipeline_state.pipeline_run_id + ) and pipeline_run_id != pipeline_state.pipeline_run_id: + raise status_lib.StatusNotOkError( + code=status_lib.Code.NOT_FOUND, + message=( + 'Unable to find an active pipeline run for pipeline_run_id: ' + f'{pipeline_run_id}' + ), + ) + nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline) + filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id] + if len(filtered_nodes) != 1: + raise status_lib.StatusNotOkError( + code=status_lib.Code.NOT_FOUND, + message=f'unable to find node: {node_uid}', + ) + node = filtered_nodes[0] + if not isinstance(node, node_proto_view.PipelineNodeProtoView): + raise ValueError( + f'Unexpected type for node {node.node_info.id}. Only ' + 'pipeline nodes are supported for external executions.' + ) + return (pipeline_state.pipeline, node) diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py index 8d4bfcdcf2..cc6fd85056 100644 --- a/tfx/orchestration/experimental/core/pipeline_state_test.py +++ b/tfx/orchestration/experimental/core/pipeline_state_test.py @@ -37,7 +37,6 @@ from tfx.proto.orchestration import run_state_pb2 from tfx.utils import json_utils from tfx.utils import status as status_lib - import ml_metadata as mlmd from ml_metadata.proto import metadata_store_pb2 @@ -1546,6 +1545,40 @@ def test_create_and_load_concurrent_pipeline_runs(self): pipeline_state_run1.pipeline_uid, ) + def test_get_pipeline_and_node(self): + with TestEnv(None, 20000), self._mlmd_connection as m: + pipeline = _test_pipeline( + 'pipeline1', + execution_mode=pipeline_pb2.Pipeline.SYNC, + pipeline_nodes=['ExampleGen', 'Trainer'], + pipeline_run_id='run0', + ) + pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) + trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer') + 
pstate.PipelineState.new(m, pipeline) + ir, npv = pstate.get_pipeline_and_node(m, trainer_node_uid, 'run0') + self.assertEqual(npv.node_info.id, 'Trainer') + self.assertEqual( + pipeline.pipeline_info, + ir.pipeline_info, + ) + + def test_get_pipeline_and_node_not_found(self): + with TestEnv(None, 20000), self._mlmd_connection as m: + pipeline = _test_pipeline( + 'pipeline1', + execution_mode=pipeline_pb2.Pipeline.SYNC, + pipeline_nodes=['ExampleGen', 'Trainer'], + pipeline_run_id='run0', + ) + with pstate.PipelineState.new(m, pipeline) as pipeline_state: + node_uid = task_lib.NodeUid( + pipeline_uid=pipeline_state.pipeline_uid, node_id='NodeDoesNotExist' + ) + + with self.assertRaises(status_lib.StatusNotOkError): + pstate.get_pipeline_and_node(m, node_uid, 'run0') + class NodeStatesProxyTest(test_utils.TfxTest): @@ -1632,6 +1665,5 @@ def test_save_with_max_str_len(self): json_utils.dumps(node_states), ) - if __name__ == '__main__': tf.test.main()