From ec6be7253f84be7da5016a0b822b385e7f5f3262 Mon Sep 17 00:00:00 2001 From: kmonte Date: Thu, 13 Jun 2024 11:35:59 -0700 Subject: [PATCH] Fully swap to using node_proto_view.get_view_for_all_in PiperOrigin-RevId: 643069039 --- .../experimental/core/pipeline_ops.py | 21 ++++---- .../experimental/core/pipeline_state.py | 49 +++++++------------ .../experimental/core/sample_mlmd_creator.py | 6 +-- tfx/orchestration/node_proto_view.py | 1 - 4 files changed, 33 insertions(+), 44 deletions(-) diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 8c07f609777..19a4bba68b3 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -459,7 +459,7 @@ def _check_nodes_exist( ) -> None: """Raises an error if node_uid does not exist in the pipeline.""" node_id_set = set(n.node_id for n in node_uids) - nodes = pstate.get_all_nodes(pipeline) + nodes = node_proto_view.get_view_for_all_in(pipeline) filtered_nodes = [n for n in nodes if n.node_info.id in node_id_set] if len(filtered_nodes) != len(node_id_set): raise status_lib.StatusNotOkError( @@ -570,7 +570,7 @@ def resume_manual_node( mlmd_handle, node_uid.pipeline_uid ) as pipeline_state: env.get_env().check_if_can_orchestrate(pipeline_state.pipeline) - nodes = pstate.get_all_nodes(pipeline_state.pipeline) + nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline) filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id] if len(filtered_nodes) != 1: raise status_lib.StatusNotOkError( @@ -959,7 +959,8 @@ def resume_pipeline( if node_state.is_success(): previously_succeeded_nodes.append(node) pipeline_nodes = [ - node.node_info.id for node in pstate.get_all_nodes(pipeline) + node.node_info.id + for node in node_proto_view.get_view_for_all_in(pipeline) ] # Mark nodes using partial pipeline run lib. @@ -1005,7 +1006,7 @@ def _recursively_revive_pipelines( ) -> pstate.PipelineState: """Recursively revives all pipelines, resuing executions if present.""" with pipeline_state: - nodes = pstate.get_all_nodes(pipeline_state.pipeline) + nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline) node_by_name = {node.node_info.id: node for node in nodes} # TODO(b/272015049): Add support for manager start nodes. nodes_to_start = [ @@ -1510,7 +1511,7 @@ def _run_end_nodes( # Build some dicts and find all paired nodes end_nodes = [] pipeline = pipeline_state.pipeline - nodes = pstate.get_all_nodes(pipeline) + nodes = node_proto_view.get_view_for_all_in(pipeline) node_uid_by_id = {} with pipeline_state: node_state_by_node_uid = pipeline_state.get_node_states_dict() @@ -1626,7 +1627,7 @@ def _orchestrate_stop_initiated_pipeline( pipeline = pipeline_state.pipeline stop_reason = pipeline_state.stop_initiated_reason() assert stop_reason is not None - for node in pstate.get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): node_uid = task_lib.NodeUid.from_node(pipeline, node) with pipeline_state.node_state_update_context(node_uid) as node_state: if node_state.is_stoppable(): @@ -1683,7 +1684,7 @@ def _orchestrate_stop_initiated_pipeline( ) if any( n.execution_options.HasField('resource_lifetime') - for n in pstate.get_all_nodes(pipeline_state.pipeline) + for n in node_proto_view.get_view_for_all_in(pipeline_state.pipeline) ): logging.info('Pipeline has paired nodes. May launch additional jobs') # Note that this is a pretty hacky "best effort" attempt at cleanup, we @@ -1725,7 +1726,7 @@ def _orchestrate_update_initiated_pipeline( else None ) pipeline = pipeline_state.pipeline - for node in pstate.get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): # TODO(b/217584342): Partial reload which excludes service nodes is not # fully supported in async pipelines since we don't have a mechanism to # reload them later for new executions. @@ -1774,7 +1775,7 @@ def _orchestrate_update_initiated_pipeline( if all_stopped: with pipeline_state: pipeline = pipeline_state.pipeline - for node in pstate.get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): # TODO(b/217584342): Partial reload which excludes service nodes is not # fully supported in async pipelines since we don't have a mechanism to # reload them later for new executions. @@ -2001,7 +2002,7 @@ def _filter_by_node_id( def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]: """Returns a list of `_NodeInfo` object for each node in the pipeline.""" - nodes = pstate.get_all_nodes(pipeline_state.pipeline) + nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline) result: List[_NodeInfo] = [] with pipeline_state: for node in nodes: diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 7a236d6c2a2..9db976639d0 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -44,7 +44,6 @@ from tfx.proto.orchestration import metadata_pb2 from tfx.proto.orchestration import pipeline_pb2 from tfx.proto.orchestration import run_state_pb2 -from tfx.utils import deprecation_utils from tfx.utils import json_utils from tfx.utils import status as status_lib @@ -969,8 +968,14 @@ def initiate_update( def _structure( pipeline: pipeline_pb2.Pipeline ) -> List[Tuple[str, List[str], List[str]]]: - return [(node.node_info.id, list(node.upstream_nodes), - list(node.downstream_nodes)) for node in get_all_nodes(pipeline)] + return [ + ( + node.node_info.id, + list(node.upstream_nodes), + list(node.downstream_nodes), + ) + for node in node_proto_view.get_view_for_all_in(pipeline) + ] if _structure(self.pipeline) != _structure(updated_pipeline): raise status_lib.StatusNotOkError( @@ -1078,7 +1083,7 @@ def get_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]: self._check_context() node_states_dict = self._node_states_proxy.get() result = {} - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): node_uid = task_lib.NodeUid.from_node(self.pipeline, node) result[node_uid] = node_states_dict.get(node_uid.node_id, NodeState()) return result @@ -1088,7 +1093,7 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]: self._check_context() node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES) result = {} - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): node_uid = task_lib.NodeUid.from_node(self.pipeline, node) if node_uid.node_id not in node_states_dict: continue @@ -1363,7 +1368,7 @@ def get_node_run_states(self) -> Dict[str, run_state_pb2.RunState]: """Returns a dict mapping node id to current run state.""" result = {} node_states_dict = self._node_states_proxy.get() - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): node_state = node_states_dict.get(node.node_info.id, NodeState()) result[node.node_info.id] = node_state.to_run_state() return result @@ -1373,7 +1378,7 @@ def get_node_run_states_history( """Returns the history of node run states and timestamps.""" node_states_dict = self._node_states_proxy.get() result = {} - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): node_state = node_states_dict.get(node.node_info.id, NodeState()) result[node.node_info.id] = node_state.to_run_state_history() return result @@ -1382,7 +1387,7 @@ def get_previous_node_run_states(self) -> Dict[str, run_state_pb2.RunState]: """Returns a dict mapping node id to previous run state.""" result = {} node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES) - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): if node.node_info.id not in node_states_dict: continue node_state = node_states_dict[node.node_info.id] @@ -1394,7 +1399,7 @@ def get_previous_node_run_states_history( """Returns a dict mapping node id to previous run state and timestamps.""" prev_node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES) result = {} - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): if node.node_info.id not in prev_node_states_dict: continue node_state = prev_node_states_dict[node.node_info.id] @@ -1410,7 +1415,7 @@ def get_node_states_dict(self) -> Dict[str, NodeState]: """Returns a dict mapping node id to node state.""" result = {} node_states_dict = self._node_states_proxy.get() - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): result[node.node_info.id] = node_states_dict.get(node.node_info.id, NodeState()) return result @@ -1419,7 +1424,7 @@ def get_previous_node_states_dict(self) -> Dict[str, NodeState]: """Returns a dict mapping node id to node state in previous run.""" result = {} node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES) - for node in get_all_nodes(self.pipeline): + for node in node_proto_view.get_view_for_all_in(self.pipeline): if node.node_info.id not in node_states_dict: continue result[node.node_info.id] = node_states_dict[node.node_info.id] @@ -1439,22 +1444,6 @@ def pipeline_id_from_orchestrator_context( return context.name -@deprecation_utils.deprecated( - None, - 'pipeline_state.get_all_nodes has been deprecated in favor of' - ' node_proto_view.get_view_for_all_in which has identical behavior.', -) -@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics) -def get_all_nodes( - pipeline: pipeline_pb2.Pipeline) -> List[node_proto_view.NodeProtoView]: - """Returns the views of nodes or inner pipelines in the given pipeline.""" - # TODO(goutham): Handle system nodes. - return [ - node_proto_view.get_view(pipeline_or_node) - for pipeline_or_node in pipeline.nodes - ] - - @telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics) def get_all_node_executions( pipeline: pipeline_pb2.Pipeline, @@ -1484,7 +1473,7 @@ def get_all_node_executions( node.node_info.id: task_gen_utils.get_executions( mlmd_handle, node, additional_filters=additional_filters ) - for node in get_all_nodes(pipeline) + for node in node_proto_view.get_view_for_all_in(pipeline) } @@ -1528,7 +1517,7 @@ def get_all_node_artifacts( def _is_node_uid_in_pipeline(node_uid: task_lib.NodeUid, pipeline: pipeline_pb2.Pipeline) -> bool: """Returns `True` if the `node_uid` belongs to the given pipeline.""" - for node in get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): if task_lib.NodeUid.from_node(pipeline, node) == node_uid: return True return False @@ -1593,7 +1582,7 @@ def _save_skipped_node_states(pipeline: pipeline_pb2.Pipeline, if reused_pipeline_view else {} ) - for node in get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): node_id = node.node_info.id if node.execution_options.HasField('skip'): logging.info('Node %s is skipped in this partial run.', node_id) diff --git a/tfx/orchestration/experimental/core/sample_mlmd_creator.py b/tfx/orchestration/experimental/core/sample_mlmd_creator.py index d41acc0af66..217d89c0f05 100644 --- a/tfx/orchestration/experimental/core/sample_mlmd_creator.py +++ b/tfx/orchestration/experimental/core/sample_mlmd_creator.py @@ -14,13 +14,13 @@ """Creates testing MLMD with TFX data model.""" import os import tempfile +from typing import Callable, Optional -from typing import Optional, Callable from absl import app from absl import flags - from tfx.dsl.compiler import constants from tfx.orchestration import metadata +from tfx.orchestration import node_proto_view from tfx.orchestration.experimental.core import pipeline_ops from tfx.orchestration.experimental.core import pipeline_state as pstate from tfx.orchestration.experimental.core import task as task_lib @@ -69,7 +69,7 @@ def _test_pipeline(ir_path: str, pipeline_id: str, run_id: str, def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline, version: int): """Creates fake execution of nodes.""" - for node in pstate.get_all_nodes(pipeline): + for node in node_proto_view.get_view_for_all_in(pipeline): if node.node_info.id == 'my_example_gen': test_utils.fake_example_gen_run_with_handle(handle, node, 1, version) else: diff --git a/tfx/orchestration/node_proto_view.py b/tfx/orchestration/node_proto_view.py index f2d2e76b8ff..2510280d1be 100644 --- a/tfx/orchestration/node_proto_view.py +++ b/tfx/orchestration/node_proto_view.py @@ -276,7 +276,6 @@ def get_view( raise ValueError(f'Got unknown pipeline or node type: {pipeline_or_node}.') -# TODO: b/270960179 - Migrate all usages of pipeline_state.get_all_nodes here. def get_view_for_all_in( pipeline: pipeline_pb2.Pipeline, ) -> Sequence[NodeProtoView]: