From 9b4c8dd4df9e893052dd7f6c76156afac5592238 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 17 Jul 2024 20:22:23 -0700 Subject: [PATCH] Internal clean up. PiperOrigin-RevId: 653453511 --- .../core/async_pipeline_task_gen.py | 1 + .../experimental/core/pipeline_state.py | 95 ++++++++++--------- .../core/sync_pipeline_task_gen.py | 2 + 3 files changed, 54 insertions(+), 44 deletions(-) diff --git a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py index 416a03cf65..60a36b773b 100644 --- a/tfx/orchestration/experimental/core/async_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/async_pipeline_task_gen.py @@ -477,6 +477,7 @@ def _generate_tasks_for_node( for input_and_param in unprocessed_inputs: if backfill_token: + assert input_and_param.exec_properties is not None input_and_param.exec_properties[ constants.BACKFILL_TOKEN_CUSTOM_PROPERTY_KEY ] = backfill_token diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py index 76559e8391..8c7338ce43 100644 --- a/tfx/orchestration/experimental/core/pipeline_state.py +++ b/tfx/orchestration/experimental/core/pipeline_state.py @@ -22,7 +22,7 @@ import os import threading import time -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple, cast import uuid from absl import logging @@ -557,7 +557,8 @@ def __init__( self._mlmd_execution_atomic_op_context = None self._execution: Optional[metadata_store_pb2.Execution] = None self._on_commit_callbacks: List[Callable[[], None]] = [] - self._node_states_proxy: Optional[_NodeStatesProxy] = None + # The node state proxy is assumed to be initialized before being used. 
+ self._node_states_proxy: _NodeStatesProxy = cast(_NodeStatesProxy, None) @classmethod @telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics) @@ -916,26 +917,29 @@ def _load_from_context( @property def execution(self) -> metadata_store_pb2.Execution: - self._check_context() + if self._execution is None: + raise RuntimeError( + 'Operation must be performed within the pipeline state context.' + ) return self._execution def is_active(self) -> bool: """Returns `True` if pipeline is active.""" - self._check_context() - return execution_lib.is_execution_active(self._execution) + return execution_lib.is_execution_active(self.execution) def initiate_stop(self, status: status_lib.Status) -> None: """Updates pipeline state to signal stopping pipeline execution.""" - self._check_context() data_types_utils.set_metadata_value( - self._execution.custom_properties[_STOP_INITIATED], 1) + self.execution.custom_properties[_STOP_INITIATED], 1 + ) data_types_utils.set_metadata_value( - self._execution.custom_properties[_PIPELINE_STATUS_CODE], - int(status.code)) + self.execution.custom_properties[_PIPELINE_STATUS_CODE], + int(status.code), + ) if status.message: data_types_utils.set_metadata_value( - self._execution.custom_properties[_PIPELINE_STATUS_MSG], - status.message) + self.execution.custom_properties[_PIPELINE_STATUS_MSG], status.message + ) @_synchronized def initiate_resume(self) -> None: @@ -994,21 +998,24 @@ def _structure( env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline) data_types_utils.set_metadata_value( - self._execution.custom_properties[_UPDATED_PIPELINE_IR], - _PipelineIRCodec.get().encode(updated_pipeline)) + self.execution.custom_properties[_UPDATED_PIPELINE_IR], + _PipelineIRCodec.get().encode(updated_pipeline), + ) data_types_utils.set_metadata_value( - self._execution.custom_properties[_UPDATE_OPTIONS], - _base64_encode(update_options)) + self.execution.custom_properties[_UPDATE_OPTIONS], + _base64_encode(update_options), + ) def 
is_update_initiated(self) -> bool: - self._check_context() - return self.is_active() and self._execution.custom_properties.get( - _UPDATED_PIPELINE_IR) is not None + return ( + self.is_active() + and self.execution.custom_properties.get(_UPDATED_PIPELINE_IR) + is not None + ) def get_update_options(self) -> pipeline_pb2.UpdateOptions: """Gets pipeline update option that was previously configured.""" - self._check_context() - update_options = self._execution.custom_properties.get(_UPDATE_OPTIONS) + update_options = self.execution.custom_properties.get(_UPDATE_OPTIONS) if update_options is None: logging.warning( 'pipeline execution missing expected custom property %s, ' @@ -1019,17 +1026,18 @@ def get_update_options(self) -> pipeline_pb2.UpdateOptions: def apply_pipeline_update(self) -> None: """Applies pipeline update that was previously initiated.""" - self._check_context() updated_pipeline_ir = _get_metadata_value( - self._execution.custom_properties.get(_UPDATED_PIPELINE_IR)) + self.execution.custom_properties.get(_UPDATED_PIPELINE_IR) + ) if not updated_pipeline_ir: raise status_lib.StatusNotOkError( code=status_lib.Code.INVALID_ARGUMENT, message='No updated pipeline IR to apply') data_types_utils.set_metadata_value( - self._execution.properties[_PIPELINE_IR], updated_pipeline_ir) - del self._execution.custom_properties[_UPDATED_PIPELINE_IR] - del self._execution.custom_properties[_UPDATE_OPTIONS] + self.execution.properties[_PIPELINE_IR], updated_pipeline_ir + ) + del self.execution.custom_properties[_UPDATED_PIPELINE_IR] + del self.execution.custom_properties[_UPDATE_OPTIONS] self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir) def is_stop_initiated(self) -> bool: @@ -1038,8 +1046,7 @@ def is_stop_initiated(self) -> bool: def stop_initiated_reason(self) -> Optional[status_lib.Status]: """Returns status object if stop initiated, `None` otherwise.""" - self._check_context() - custom_properties = self._execution.custom_properties + custom_properties 
= self.execution.custom_properties if _get_metadata_value(custom_properties.get(_STOP_INITIATED)) == 1: code = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_CODE)) if code is None: @@ -1111,45 +1118,44 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]: def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State: """Returns state of underlying pipeline execution.""" - self._check_context() - return self._execution.last_known_state + return self.execution.last_known_state def set_pipeline_execution_state( self, state: metadata_store_pb2.Execution.State) -> None: """Sets state of underlying pipeline execution.""" - self._check_context() - - if self._execution.last_known_state != state: + if self.execution.last_known_state != state: self._on_commit_callbacks.append( - functools.partial(_log_pipeline_execution_state_change, - self._execution.last_known_state, state, - self.pipeline_uid)) - self._execution.last_known_state = state + functools.partial( + _log_pipeline_execution_state_change, + self.execution.last_known_state, + state, + self.pipeline_uid, + ) + ) + self.execution.last_known_state = state def get_property(self, property_key: str) -> Optional[types.Property]: """Returns custom property value from the pipeline execution.""" return _get_metadata_value( - self._execution.custom_properties.get(property_key)) + self.execution.custom_properties.get(property_key) + ) def save_property( self, property_key: str, property_value: types.Property ) -> None: - self._check_context() data_types_utils.set_metadata_value( - self._execution.custom_properties[property_key], property_value + self.execution.custom_properties[property_key], property_value ) def remove_property(self, property_key: str) -> None: """Removes a custom property of the pipeline execution if exists.""" - self._check_context() - if self._execution.custom_properties.get(property_key): - del self._execution.custom_properties[property_key] + if 
self.execution.custom_properties.get(property_key): + del self.execution.custom_properties[property_key] def pipeline_creation_time_secs_since_epoch(self) -> int: """Returns the pipeline creation time as seconds since epoch.""" - self._check_context() # Convert from milliseconds to seconds. - return self._execution.create_time_since_epoch // 1000 + return self.execution.create_time_since_epoch // 1000 def get_orchestration_options( self) -> orchestration_options.OrchestrationOptions: @@ -1197,6 +1203,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._mlmd_execution_atomic_op_context = None self._execution = None try: + assert mlmd_execution_atomic_op_context is not None mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb) finally: self._on_commit_callbacks.clear() diff --git a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py index 8726256b96..04f49cdeca 100644 --- a/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py +++ b/tfx/orchestration/experimental/core/sync_pipeline_task_gen.py @@ -169,6 +169,7 @@ def __call__(self) -> List[task_lib.Task]: ): successful_node_ids.add(node_id) elif node_state.is_failure(): + assert node_state.status is not None failed_nodes_dict[node_id] = node_state.status # Collect nodes that cannot be run because they have a failed ancestor. @@ -545,6 +546,7 @@ def _generate_tasks_from_resolved_inputs( # executions. Idempotency is guaranteed by external_id. updated_external_artifacts = [] for input_and_params in resolved_info.input_and_params: + assert input_and_params.input_artifacts is not None for artifacts in input_and_params.input_artifacts.values(): updated_external_artifacts.extend( task_gen_utils.update_external_artifact_type(