Skip to content

Commit

Permalink
Internal clean up.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653453511
  • Loading branch information
tfx-copybara committed Jul 19, 2024
1 parent c0af966 commit 9b4c8dd
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def _generate_tasks_for_node(

for input_and_param in unprocessed_inputs:
if backfill_token:
assert input_and_param.exec_properties is not None
input_and_param.exec_properties[
constants.BACKFILL_TOKEN_CUSTOM_PROPERTY_KEY
] = backfill_token
Expand Down
95 changes: 51 additions & 44 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import threading
import time
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple, cast
import uuid

from absl import logging
Expand Down Expand Up @@ -557,7 +557,8 @@ def __init__(
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None
# The node states proxy is assumed to be initialized before being used.
self._node_states_proxy: _NodeStatesProxy = cast(_NodeStatesProxy, None)

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
Expand Down Expand Up @@ -916,26 +917,29 @@ def _load_from_context(

@property
def execution(self) -> metadata_store_pb2.Execution:
self._check_context()
if self._execution is None:
raise RuntimeError(
'Operation must be performed within the pipeline state context.'
)
return self._execution

def is_active(self) -> bool:
"""Returns `True` if pipeline is active."""
self._check_context()
return execution_lib.is_execution_active(self._execution)
return execution_lib.is_execution_active(self.execution)

def initiate_stop(self, status: status_lib.Status) -> None:
"""Updates pipeline state to signal stopping pipeline execution."""
self._check_context()
data_types_utils.set_metadata_value(
self._execution.custom_properties[_STOP_INITIATED], 1)
self.execution.custom_properties[_STOP_INITIATED], 1
)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_PIPELINE_STATUS_CODE],
int(status.code))
self.execution.custom_properties[_PIPELINE_STATUS_CODE],
int(status.code),
)
if status.message:
data_types_utils.set_metadata_value(
self._execution.custom_properties[_PIPELINE_STATUS_MSG],
status.message)
self.execution.custom_properties[_PIPELINE_STATUS_MSG], status.message
)

@_synchronized
def initiate_resume(self) -> None:
Expand Down Expand Up @@ -994,21 +998,24 @@ def _structure(

env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_UPDATED_PIPELINE_IR],
_PipelineIRCodec.get().encode(updated_pipeline))
self.execution.custom_properties[_UPDATED_PIPELINE_IR],
_PipelineIRCodec.get().encode(updated_pipeline),
)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_UPDATE_OPTIONS],
_base64_encode(update_options))
self.execution.custom_properties[_UPDATE_OPTIONS],
_base64_encode(update_options),
)

def is_update_initiated(self) -> bool:
self._check_context()
return self.is_active() and self._execution.custom_properties.get(
_UPDATED_PIPELINE_IR) is not None
return (
self.is_active()
and self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
is not None
)

def get_update_options(self) -> pipeline_pb2.UpdateOptions:
"""Gets pipeline update option that was previously configured."""
self._check_context()
update_options = self._execution.custom_properties.get(_UPDATE_OPTIONS)
update_options = self.execution.custom_properties.get(_UPDATE_OPTIONS)
if update_options is None:
logging.warning(
'pipeline execution missing expected custom property %s, '
Expand All @@ -1019,17 +1026,18 @@ def get_update_options(self) -> pipeline_pb2.UpdateOptions:

def apply_pipeline_update(self) -> None:
"""Applies pipeline update that was previously initiated."""
self._check_context()
updated_pipeline_ir = _get_metadata_value(
self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
)
if not updated_pipeline_ir:
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT,
message='No updated pipeline IR to apply')
data_types_utils.set_metadata_value(
self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
del self._execution.custom_properties[_UPDATE_OPTIONS]
self.execution.properties[_PIPELINE_IR], updated_pipeline_ir
)
del self.execution.custom_properties[_UPDATED_PIPELINE_IR]
del self.execution.custom_properties[_UPDATE_OPTIONS]
self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir)

def is_stop_initiated(self) -> bool:
Expand All @@ -1038,8 +1046,7 @@ def is_stop_initiated(self) -> bool:

def stop_initiated_reason(self) -> Optional[status_lib.Status]:
"""Returns status object if stop initiated, `None` otherwise."""
self._check_context()
custom_properties = self._execution.custom_properties
custom_properties = self.execution.custom_properties
if _get_metadata_value(custom_properties.get(_STOP_INITIATED)) == 1:
code = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_CODE))
if code is None:
Expand Down Expand Up @@ -1111,45 +1118,44 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:

def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State:
"""Returns state of underlying pipeline execution."""
self._check_context()
return self._execution.last_known_state
return self.execution.last_known_state

def set_pipeline_execution_state(
self, state: metadata_store_pb2.Execution.State) -> None:
"""Sets state of underlying pipeline execution."""
self._check_context()

if self._execution.last_known_state != state:
if self.execution.last_known_state != state:
self._on_commit_callbacks.append(
functools.partial(_log_pipeline_execution_state_change,
self._execution.last_known_state, state,
self.pipeline_uid))
self._execution.last_known_state = state
functools.partial(
_log_pipeline_execution_state_change,
self.execution.last_known_state,
state,
self.pipeline_uid,
)
)
self.execution.last_known_state = state

def get_property(self, property_key: str) -> Optional[types.Property]:
"""Returns custom property value from the pipeline execution."""
return _get_metadata_value(
self._execution.custom_properties.get(property_key))
self.execution.custom_properties.get(property_key)
)

def save_property(
self, property_key: str, property_value: types.Property
) -> None:
self._check_context()
data_types_utils.set_metadata_value(
self._execution.custom_properties[property_key], property_value
self.execution.custom_properties[property_key], property_value
)

def remove_property(self, property_key: str) -> None:
"""Removes a custom property of the pipeline execution if exists."""
self._check_context()
if self._execution.custom_properties.get(property_key):
del self._execution.custom_properties[property_key]
if self.execution.custom_properties.get(property_key):
del self.execution.custom_properties[property_key]

def pipeline_creation_time_secs_since_epoch(self) -> int:
"""Returns the pipeline creation time as seconds since epoch."""
self._check_context()
# Convert from milliseconds to seconds.
return self._execution.create_time_since_epoch // 1000
return self.execution.create_time_since_epoch // 1000

def get_orchestration_options(
self) -> orchestration_options.OrchestrationOptions:
Expand Down Expand Up @@ -1197,6 +1203,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
assert mlmd_execution_atomic_op_context is not None
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()
Expand Down
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __call__(self) -> List[task_lib.Task]:
):
successful_node_ids.add(node_id)
elif node_state.is_failure():
assert node_state.status is not None
failed_nodes_dict[node_id] = node_state.status

# Collect nodes that cannot be run because they have a failed ancestor.
Expand Down Expand Up @@ -545,6 +546,7 @@ def _generate_tasks_from_resolved_inputs(
# executions. Idempotency is guaranteed by external_id.
updated_external_artifacts = []
for input_and_params in resolved_info.input_and_params:
assert input_and_params.input_artifacts is not None
for artifacts in input_and_params.input_artifacts.values():
updated_external_artifacts.extend(
task_gen_utils.update_external_artifact_type(
Expand Down

0 comments on commit 9b4c8dd

Please sign in to comment.