Skip to content

Commit

Permalink
Refactor the _get_pipeline_and_node function.
Browse files Browse the repository at this point in the history
Moved the function from pipeline_ops_extensions.py to pipeline_state.py
and updated its function calls in pipeline_ops_extensions.py accordingly.
Also added tests for the _get_pipeline_and_node function in pipeline_state_test.py.

PiperOrigin-RevId: 643421872
  • Loading branch information
tfx-copybara committed Jun 14, 2024
1 parent 74ba85e commit 78441be
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
45 changes: 45 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,3 +1669,48 @@ def _get_sub_pipeline_ids_from_pipeline_info(
sub_pipeline_ids = pipeline_info.parent_ids[1:]
sub_pipeline_ids.append(pipeline_info.id)
return sub_pipeline_ids


def get_pipeline_and_node(
    mlmd_handle: metadata.Metadata,
    node_uid: task_lib.NodeUid,
    pipeline_run_id: str,
) -> tuple[pipeline_pb2.Pipeline, node_proto_view.PipelineNodeProtoView]:
  """Looks up the pipeline and the node view identified by `node_uid`.

  This function is experimental, and should only be used when publishing
  external and intermediate artifacts.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to look up; its `pipeline_uid` selects the
      pipeline state that gets loaded.
    pipeline_run_id: Run id expected on the loaded pipeline state (the run id
      of the pipeline for a synchronous pipeline). A falsy value is accepted
      when the loaded state's own run id is falsy as well.

  Returns:
    A `(pipeline, node)` tuple holding the pipeline of the loaded state and
    the proto view of the node whose id equals `node_uid.node_id`.

  Raises:
    status_lib.StatusNotOkError: With code `NOT_FOUND` when `pipeline_run_id`
      differs from the run id of the loaded pipeline state, or when the
      pipeline does not contain exactly one node with the requested id.
    ValueError: When the matching node is not a `PipelineNodeProtoView`.
  """
  with PipelineState.load(mlmd_handle, node_uid.pipeline_uid) as state:
    active_run_id = state.pipeline_run_id
    # Only compare run ids when at least one side carries a value, so that two
    # differently-spelled "empty" ids are not treated as a mismatch.
    if (pipeline_run_id or active_run_id) and pipeline_run_id != active_run_id:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.NOT_FOUND,
          message=(
              'Unable to find an active pipeline run for pipeline_run_id: '
              f'{pipeline_run_id}'
          ),
      )

    matches = [
        view
        for view in node_proto_view.get_view_for_all_in(state.pipeline)
        if view.node_info.id == node_uid.node_id
    ]
    # Zero matches and multiple matches are both reported as "not found".
    if len(matches) != 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.NOT_FOUND,
          message=f'unable to find node: {node_uid}',
      )

    (node,) = matches
    if isinstance(node, node_proto_view.PipelineNodeProtoView):
      return state.pipeline, node
    raise ValueError(
        f'Unexpected type for node {node.node_info.id}. Only '
        'pipeline nodes are supported for external executions.'
    )
36 changes: 34 additions & 2 deletions tfx/orchestration/experimental/core/pipeline_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import json_utils
from tfx.utils import status as status_lib

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Expand Down Expand Up @@ -1546,6 +1545,40 @@ def test_create_and_load_concurrent_pipeline_runs(self):
pipeline_state_run1.pipeline_uid,
)

def test_get_pipeline_and_node(self):
  """Checks that an existing node and its pipeline are returned for 'run0'."""
  with TestEnv(None, 20000), self._mlmd_connection as mlmd_handle:
    pipeline = _test_pipeline(
        'pipeline1',
        execution_mode=pipeline_pb2.Pipeline.SYNC,
        pipeline_nodes=['ExampleGen', 'Trainer'],
        pipeline_run_id='run0',
    )
    trainer_node_uid = task_lib.NodeUid(
        task_lib.PipelineUid.from_pipeline(pipeline), 'Trainer'
    )
    # Register the pipeline state so the lookup below has something to load.
    pstate.PipelineState.new(mlmd_handle, pipeline)

    loaded_pipeline, node_view = pstate.get_pipeline_and_node(
        mlmd_handle, trainer_node_uid, 'run0'
    )

    self.assertEqual(node_view.node_info.id, 'Trainer')
    self.assertEqual(pipeline.pipeline_info, loaded_pipeline.pipeline_info)

def test_get_pipeline_and_node_not_found(self):
  """Checks that looking up an unknown node id raises StatusNotOkError."""
  with TestEnv(None, 20000), self._mlmd_connection as mlmd_handle:
    pipeline = _test_pipeline(
        'pipeline1',
        execution_mode=pipeline_pb2.Pipeline.SYNC,
        pipeline_nodes=['ExampleGen', 'Trainer'],
        pipeline_run_id='run0',
    )
    with pstate.PipelineState.new(mlmd_handle, pipeline) as state:
      missing_node_uid = task_lib.NodeUid(
          pipeline_uid=state.pipeline_uid, node_id='NodeDoesNotExist'
      )

    # NOTE(review): nesting was reconstructed from a flattened listing; the
    # lookup is assumed to run after the `PipelineState.new` context has
    # exited - confirm against the repository.
    with self.assertRaises(status_lib.StatusNotOkError):
      pstate.get_pipeline_and_node(mlmd_handle, missing_node_uid, 'run0')


class NodeStatesProxyTest(test_utils.TfxTest):

Expand Down Expand Up @@ -1632,6 +1665,5 @@ def test_save_with_max_str_len(self):
json_utils.dumps(node_states),
)


# Run this module's tests only when it is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()

0 comments on commit 78441be

Please sign in to comment.