Skip to content

Commit

Permalink
remove PAUSING and PAUSED
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605551633
  • Loading branch information
tfx-copybara committed Feb 9, 2024
1 parent 6764b7c commit db9f112
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
46 changes: 27 additions & 19 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
# cache misses).
_IN_MEMORY_PREDICATE_FN_DEFAULT_POLLING_INTERVAL_SECS = 1.0

# A special message indicating that a node is stopped by the command Update.
_STOPPED_BY_UPDATE = 'Stopped by Update command'


def _pipeline_op(lock: bool = True):
"""Decorator factory for pipeline ops."""
Expand Down Expand Up @@ -1634,7 +1637,7 @@ def _orchestrate_update_initiated_pipeline(
pipeline_state: pstate.PipelineState,
) -> None:
"""Orchestrates an update-initiated pipeline."""
nodes_to_pause = []
nodes_to_stop = []
with pipeline_state:
update_options = pipeline_state.get_update_options()
reload_node_ids = (
Expand All @@ -1654,40 +1657,42 @@ def _orchestrate_update_initiated_pipeline(
continue
node_uid = task_lib.NodeUid.from_node(pipeline, node)
with pipeline_state.node_state_update_context(node_uid) as node_state:
if node_state.is_pausable():
if node_state.is_stoppable():
node_state.update(
pstate.NodeState.PAUSING,
status_lib.Status(code=status_lib.Code.CANCELLED),
pstate.NodeState.STOPPING,
status_lib.Status(
code=status_lib.Code.CANCELLED, message=_STOPPED_BY_UPDATE
),
)
if node_state.state == pstate.NodeState.PAUSING:
nodes_to_pause.append(node)
if node_state.state == pstate.NodeState.STOPPING:
nodes_to_stop.append(node)

# Issue cancellation for nodes_to_pause and gather the ones whose pausing is
# Issue cancellation for nodes_to_stop and gather the ones whose STOPPING is
# complete.
paused_nodes = []
for node in nodes_to_pause:
stopped_nodes = []
for node in nodes_to_stop:
if _cancel_node(
mlmd_handle,
task_queue,
service_job_manager,
pipeline_state,
node,
):
paused_nodes.append(node)
stopped_nodes.append(node)

# Change the state of paused nodes to PAUSED.
# Change the state of stopped nodes to STOPPED.
with pipeline_state:
for node in paused_nodes:
for node in stopped_nodes:
node_uid = task_lib.NodeUid.from_node(pipeline, node)
with pipeline_state.node_state_update_context(node_uid) as node_state:
node_state.update(pstate.NodeState.PAUSED, node_state.status)
node_state.update(pstate.NodeState.STOPPED, node_state.status)

# If all the pausable nodes have been paused, we can update the node state to
# STARTED.
all_paused = set(n.node_info.id for n in nodes_to_pause) == set(
n.node_info.id for n in paused_nodes
# If all the stoppable nodes have been stopped, we can update the node state
# to STARTED.
all_stopped = set(n.node_info.id for n in nodes_to_stop) == set(
n.node_info.id for n in stopped_nodes
)
if all_paused:
if all_stopped:
with pipeline_state:
pipeline = pipeline_state.pipeline
for node in pstate.get_all_nodes(pipeline):
Expand All @@ -1701,7 +1706,10 @@ def _orchestrate_update_initiated_pipeline(
continue
node_uid = task_lib.NodeUid.from_node(pipeline, node)
with pipeline_state.node_state_update_context(node_uid) as node_state:
if node_state.state == pstate.NodeState.PAUSED:
if (
node_state.state == pstate.NodeState.STOPPED
and node_state.status_msg == _STOPPED_BY_UPDATE
):
node_state.update(pstate.NodeState.STARTED)

pipeline_state.apply_pipeline_update()
Expand Down
6 changes: 2 additions & 4 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class NodeState(json_utils.Jsonable):
SKIPPED = 'skipped'
# Node execution skipped due to partial run.
SKIPPED_PARTIAL_RUN = 'skipped_partial_run'
# b/323371181 Remove pausing and paused when we are sure that no active
# pipeline has node state of pausing or paused.
PAUSING = 'pausing' # Pending work before state can change to PAUSED.
PAUSED = 'paused' # Node was paused and may be resumed in the future.
FAILED = 'failed' # Node execution failed due to errors.
Expand Down Expand Up @@ -192,10 +194,6 @@ def is_stoppable(self) -> bool:
"""Returns True if the node can be stopped."""
return self.state in set([self.STARTED, self.RUNNING, self.PAUSED])

def is_pausable(self) -> bool:
  """Returns True if the node can be paused.

  A node counts as pausable only while its state is STARTED or RUNNING.
  """
  return self.state in {self.STARTED, self.RUNNING}

def is_backfillable(self) -> bool:
  """Returns True if the node can be backfilled.

  Only nodes whose state is STOPPED or FAILED qualify.
  """
  return self.state in {self.STOPPED, self.FAILED}
Expand Down

0 comments on commit db9f112

Please sign in to comment.