Skip to content

Commit

Permalink
Fully swap to using node_proto_view.get_view_for_all_in
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643069039
  • Loading branch information
kmonte authored and tfx-copybara committed Jun 14, 2024
1 parent 74ba85e commit cb7628d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 44 deletions.
21 changes: 11 additions & 10 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def _check_nodes_exist(
) -> None:
"""Raises an error if node_uid does not exist in the pipeline."""
node_id_set = set(n.node_id for n in node_uids)
nodes = pstate.get_all_nodes(pipeline)
nodes = node_proto_view.get_view_for_all_in(pipeline)
filtered_nodes = [n for n in nodes if n.node_info.id in node_id_set]
if len(filtered_nodes) != len(node_id_set):
raise status_lib.StatusNotOkError(
Expand Down Expand Up @@ -570,7 +570,7 @@ def resume_manual_node(
mlmd_handle, node_uid.pipeline_uid
) as pipeline_state:
env.get_env().check_if_can_orchestrate(pipeline_state.pipeline)
nodes = pstate.get_all_nodes(pipeline_state.pipeline)
nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline)
filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id]
if len(filtered_nodes) != 1:
raise status_lib.StatusNotOkError(
Expand Down Expand Up @@ -959,7 +959,8 @@ def resume_pipeline(
if node_state.is_success():
previously_succeeded_nodes.append(node)
pipeline_nodes = [
node.node_info.id for node in pstate.get_all_nodes(pipeline)
node.node_info.id
for node in node_proto_view.get_view_for_all_in(pipeline)
]

# Mark nodes using partial pipeline run lib.
Expand Down Expand Up @@ -1005,7 +1006,7 @@ def _recursively_revive_pipelines(
) -> pstate.PipelineState:
"""Recursively revives all pipelines, resuing executions if present."""
with pipeline_state:
nodes = pstate.get_all_nodes(pipeline_state.pipeline)
nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline)
node_by_name = {node.node_info.id: node for node in nodes}
# TODO(b/272015049): Add support for manager start nodes.
nodes_to_start = [
Expand Down Expand Up @@ -1510,7 +1511,7 @@ def _run_end_nodes(
# Build some dicts and find all paired nodes
end_nodes = []
pipeline = pipeline_state.pipeline
nodes = pstate.get_all_nodes(pipeline)
nodes = node_proto_view.get_view_for_all_in(pipeline)
node_uid_by_id = {}
with pipeline_state:
node_state_by_node_uid = pipeline_state.get_node_states_dict()
Expand Down Expand Up @@ -1626,7 +1627,7 @@ def _orchestrate_stop_initiated_pipeline(
pipeline = pipeline_state.pipeline
stop_reason = pipeline_state.stop_initiated_reason()
assert stop_reason is not None
for node in pstate.get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
node_uid = task_lib.NodeUid.from_node(pipeline, node)
with pipeline_state.node_state_update_context(node_uid) as node_state:
if node_state.is_stoppable():
Expand Down Expand Up @@ -1683,7 +1684,7 @@ def _orchestrate_stop_initiated_pipeline(
)
if any(
n.execution_options.HasField('resource_lifetime')
for n in pstate.get_all_nodes(pipeline_state.pipeline)
for n in node_proto_view.get_view_for_all_in(pipeline_state.pipeline)
):
logging.info('Pipeline has paired nodes. May launch additional jobs')
# Note that this is a pretty hacky "best effort" attempt at cleanup, we
Expand Down Expand Up @@ -1725,7 +1726,7 @@ def _orchestrate_update_initiated_pipeline(
else None
)
pipeline = pipeline_state.pipeline
for node in pstate.get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
# TODO(b/217584342): Partial reload which excludes service nodes is not
# fully supported in async pipelines since we don't have a mechanism to
# reload them later for new executions.
Expand Down Expand Up @@ -1774,7 +1775,7 @@ def _orchestrate_update_initiated_pipeline(
if all_stopped:
with pipeline_state:
pipeline = pipeline_state.pipeline
for node in pstate.get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
# TODO(b/217584342): Partial reload which excludes service nodes is not
# fully supported in async pipelines since we don't have a mechanism to
# reload them later for new executions.
Expand Down Expand Up @@ -2001,7 +2002,7 @@ def _filter_by_node_id(

def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
"""Returns a list of `_NodeInfo` object for each node in the pipeline."""
nodes = pstate.get_all_nodes(pipeline_state.pipeline)
nodes = node_proto_view.get_view_for_all_in(pipeline_state.pipeline)
result: List[_NodeInfo] = []
with pipeline_state:
for node in nodes:
Expand Down
49 changes: 19 additions & 30 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from tfx.proto.orchestration import metadata_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import deprecation_utils
from tfx.utils import json_utils
from tfx.utils import status as status_lib

Expand Down Expand Up @@ -969,8 +968,14 @@ def initiate_update(
def _structure(
pipeline: pipeline_pb2.Pipeline
) -> List[Tuple[str, List[str], List[str]]]:
return [(node.node_info.id, list(node.upstream_nodes),
list(node.downstream_nodes)) for node in get_all_nodes(pipeline)]
return [
(
node.node_info.id,
list(node.upstream_nodes),
list(node.downstream_nodes),
)
for node in node_proto_view.get_view_for_all_in(pipeline)
]

if _structure(self.pipeline) != _structure(updated_pipeline):
raise status_lib.StatusNotOkError(
Expand Down Expand Up @@ -1078,7 +1083,7 @@ def get_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:
self._check_context()
node_states_dict = self._node_states_proxy.get()
result = {}
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
node_uid = task_lib.NodeUid.from_node(self.pipeline, node)
result[node_uid] = node_states_dict.get(node_uid.node_id, NodeState())
return result
Expand All @@ -1088,7 +1093,7 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:
self._check_context()
node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES)
result = {}
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
node_uid = task_lib.NodeUid.from_node(self.pipeline, node)
if node_uid.node_id not in node_states_dict:
continue
Expand Down Expand Up @@ -1363,7 +1368,7 @@ def get_node_run_states(self) -> Dict[str, run_state_pb2.RunState]:
"""Returns a dict mapping node id to current run state."""
result = {}
node_states_dict = self._node_states_proxy.get()
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
node_state = node_states_dict.get(node.node_info.id, NodeState())
result[node.node_info.id] = node_state.to_run_state()
return result
Expand All @@ -1373,7 +1378,7 @@ def get_node_run_states_history(
"""Returns the history of node run states and timestamps."""
node_states_dict = self._node_states_proxy.get()
result = {}
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
node_state = node_states_dict.get(node.node_info.id, NodeState())
result[node.node_info.id] = node_state.to_run_state_history()
return result
Expand All @@ -1382,7 +1387,7 @@ def get_previous_node_run_states(self) -> Dict[str, run_state_pb2.RunState]:
"""Returns a dict mapping node id to previous run state."""
result = {}
node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES)
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
if node.node_info.id not in node_states_dict:
continue
node_state = node_states_dict[node.node_info.id]
Expand All @@ -1394,7 +1399,7 @@ def get_previous_node_run_states_history(
"""Returns a dict mapping node id to previous run state and timestamps."""
prev_node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES)
result = {}
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
if node.node_info.id not in prev_node_states_dict:
continue
node_state = prev_node_states_dict[node.node_info.id]
Expand All @@ -1410,7 +1415,7 @@ def get_node_states_dict(self) -> Dict[str, NodeState]:
"""Returns a dict mapping node id to node state."""
result = {}
node_states_dict = self._node_states_proxy.get()
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
result[node.node_info.id] = node_states_dict.get(node.node_info.id,
NodeState())
return result
Expand All @@ -1419,7 +1424,7 @@ def get_previous_node_states_dict(self) -> Dict[str, NodeState]:
"""Returns a dict mapping node id to node state in previous run."""
result = {}
node_states_dict = self._node_states_proxy.get(_PREVIOUS_NODE_STATES)
for node in get_all_nodes(self.pipeline):
for node in node_proto_view.get_view_for_all_in(self.pipeline):
if node.node_info.id not in node_states_dict:
continue
result[node.node_info.id] = node_states_dict[node.node_info.id]
Expand All @@ -1439,22 +1444,6 @@ def pipeline_id_from_orchestrator_context(
return context.name


@deprecation_utils.deprecated(
None,
'pipeline_state.get_all_nodes has been deprecated in favor of'
' node_proto_view.get_view_for_all_in which has identical behavior.',
)
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
def get_all_nodes(
pipeline: pipeline_pb2.Pipeline) -> List[node_proto_view.NodeProtoView]:
"""Returns the views of nodes or inner pipelines in the given pipeline."""
# TODO(goutham): Handle system nodes.
return [
node_proto_view.get_view(pipeline_or_node)
for pipeline_or_node in pipeline.nodes
]


@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
def get_all_node_executions(
pipeline: pipeline_pb2.Pipeline,
Expand Down Expand Up @@ -1484,7 +1473,7 @@ def get_all_node_executions(
node.node_info.id: task_gen_utils.get_executions(
mlmd_handle, node, additional_filters=additional_filters
)
for node in get_all_nodes(pipeline)
for node in node_proto_view.get_view_for_all_in(pipeline)
}


Expand Down Expand Up @@ -1528,7 +1517,7 @@ def get_all_node_artifacts(
def _is_node_uid_in_pipeline(node_uid: task_lib.NodeUid,
pipeline: pipeline_pb2.Pipeline) -> bool:
"""Returns `True` if the `node_uid` belongs to the given pipeline."""
for node in get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
if task_lib.NodeUid.from_node(pipeline, node) == node_uid:
return True
return False
Expand Down Expand Up @@ -1593,7 +1582,7 @@ def _save_skipped_node_states(pipeline: pipeline_pb2.Pipeline,
if reused_pipeline_view
else {}
)
for node in get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
node_id = node.node_info.id
if node.execution_options.HasField('skip'):
logging.info('Node %s is skipped in this partial run.', node_id)
Expand Down
6 changes: 3 additions & 3 deletions tfx/orchestration/experimental/core/sample_mlmd_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
"""Creates testing MLMD with TFX data model."""
import os
import tempfile
from typing import Callable, Optional

from typing import Optional, Callable
from absl import app
from absl import flags

from tfx.dsl.compiler import constants
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration.experimental.core import pipeline_ops
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import task as task_lib
Expand Down Expand Up @@ -69,7 +69,7 @@ def _test_pipeline(ir_path: str, pipeline_id: str, run_id: str,
def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
version: int):
"""Creates fake execution of nodes."""
for node in pstate.get_all_nodes(pipeline):
for node in node_proto_view.get_view_for_all_in(pipeline):
if node.node_info.id == 'my_example_gen':
test_utils.fake_example_gen_run_with_handle(handle, node, 1, version)
else:
Expand Down
1 change: 0 additions & 1 deletion tfx/orchestration/node_proto_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def get_view(
raise ValueError(f'Got unknown pipeline or node type: {pipeline_or_node}.')


# TODO: b/270960179 - Migrate all usages of pipeline_state.get_all_nodes here.
def get_view_for_all_in(
pipeline: pipeline_pb2.Pipeline,
) -> Sequence[NodeProtoView]:
Expand Down

0 comments on commit cb7628d

Please sign in to comment.