Skip to content

Commit

Permalink
Internal clean up.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 653453511
  • Loading branch information
tfx-copybara committed Jul 18, 2024
1 parent 12051eb commit ad04c2d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 43 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -477,6 +477,7 @@ def _generate_tasks_for_node(

for input_and_param in unprocessed_inputs:
if backfill_token:
assert input_and_param.exec_properties is not None
input_and_param.exec_properties[
constants.BACKFILL_TOKEN_CUSTOM_PROPERTY_KEY
] = backfill_token
Expand Down
92 changes: 49 additions & 43 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -548,7 +548,7 @@ def __init__(
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None
self._node_states_proxy: _NodeStatesProxy = None

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
Expand Down Expand Up @@ -907,26 +907,29 @@ def _load_from_context(

@property
def execution(self) -> metadata_store_pb2.Execution:
self._check_context()
if self._execution is None:
raise RuntimeError(
'Operation must be performed within the pipeline state context.'
)
return self._execution

def is_active(self) -> bool:
"""Returns `True` if pipeline is active."""
self._check_context()
return execution_lib.is_execution_active(self._execution)
return execution_lib.is_execution_active(self.execution)

def initiate_stop(self, status: status_lib.Status) -> None:
"""Updates pipeline state to signal stopping pipeline execution."""
self._check_context()
data_types_utils.set_metadata_value(
self._execution.custom_properties[_STOP_INITIATED], 1)
self.execution.custom_properties[_STOP_INITIATED], 1
)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_PIPELINE_STATUS_CODE],
int(status.code))
self.execution.custom_properties[_PIPELINE_STATUS_CODE],
int(status.code),
)
if status.message:
data_types_utils.set_metadata_value(
self._execution.custom_properties[_PIPELINE_STATUS_MSG],
status.message)
self.execution.custom_properties[_PIPELINE_STATUS_MSG], status.message
)

@_synchronized
def initiate_resume(self) -> None:
Expand Down Expand Up @@ -985,21 +988,24 @@ def _structure(

env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_UPDATED_PIPELINE_IR],
_PipelineIRCodec.get().encode(updated_pipeline))
self.execution.custom_properties[_UPDATED_PIPELINE_IR],
_PipelineIRCodec.get().encode(updated_pipeline),
)
data_types_utils.set_metadata_value(
self._execution.custom_properties[_UPDATE_OPTIONS],
_base64_encode(update_options))
self.execution.custom_properties[_UPDATE_OPTIONS],
_base64_encode(update_options),
)

def is_update_initiated(self) -> bool:
self._check_context()
return self.is_active() and self._execution.custom_properties.get(
_UPDATED_PIPELINE_IR) is not None
return (
self.is_active()
and self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
is not None
)

def get_update_options(self) -> pipeline_pb2.UpdateOptions:
"""Gets pipeline update option that was previously configured."""
self._check_context()
update_options = self._execution.custom_properties.get(_UPDATE_OPTIONS)
update_options = self.execution.custom_properties.get(_UPDATE_OPTIONS)
if update_options is None:
logging.warning(
'pipeline execution missing expected custom property %s, '
Expand All @@ -1010,17 +1016,18 @@ def get_update_options(self) -> pipeline_pb2.UpdateOptions:

def apply_pipeline_update(self) -> None:
"""Applies pipeline update that was previously initiated."""
self._check_context()
updated_pipeline_ir = _get_metadata_value(
self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
)
if not updated_pipeline_ir:
raise status_lib.StatusNotOkError(
code=status_lib.Code.INVALID_ARGUMENT,
message='No updated pipeline IR to apply')
data_types_utils.set_metadata_value(
self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
del self._execution.custom_properties[_UPDATE_OPTIONS]
self.execution.properties[_PIPELINE_IR], updated_pipeline_ir
)
del self.execution.custom_properties[_UPDATED_PIPELINE_IR]
del self.execution.custom_properties[_UPDATE_OPTIONS]
self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir)

def is_stop_initiated(self) -> bool:
Expand All @@ -1029,8 +1036,7 @@ def is_stop_initiated(self) -> bool:

def stop_initiated_reason(self) -> Optional[status_lib.Status]:
"""Returns status object if stop initiated, `None` otherwise."""
self._check_context()
custom_properties = self._execution.custom_properties
custom_properties = self.execution.custom_properties
if _get_metadata_value(custom_properties.get(_STOP_INITIATED)) == 1:
code = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_CODE))
if code is None:
Expand Down Expand Up @@ -1102,45 +1108,44 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:

def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State:
"""Returns state of underlying pipeline execution."""
self._check_context()
return self._execution.last_known_state
return self.execution.last_known_state

def set_pipeline_execution_state(
self, state: metadata_store_pb2.Execution.State) -> None:
"""Sets state of underlying pipeline execution."""
self._check_context()

if self._execution.last_known_state != state:
if self.execution.last_known_state != state:
self._on_commit_callbacks.append(
functools.partial(_log_pipeline_execution_state_change,
self._execution.last_known_state, state,
self.pipeline_uid))
self._execution.last_known_state = state
functools.partial(
_log_pipeline_execution_state_change,
self.execution.last_known_state,
state,
self.pipeline_uid,
)
)
self.execution.last_known_state = state

def get_property(self, property_key: str) -> Optional[types.Property]:
"""Returns custom property value from the pipeline execution."""
return _get_metadata_value(
self._execution.custom_properties.get(property_key))
self.execution.custom_properties.get(property_key)
)

def save_property(
self, property_key: str, property_value: types.Property
) -> None:
self._check_context()
data_types_utils.set_metadata_value(
self._execution.custom_properties[property_key], property_value
self.execution.custom_properties[property_key], property_value
)

def remove_property(self, property_key: str) -> None:
"""Removes a custom property of the pipeline execution if exists."""
self._check_context()
if self._execution.custom_properties.get(property_key):
del self._execution.custom_properties[property_key]
if self.execution.custom_properties.get(property_key):
del self.execution.custom_properties[property_key]

def pipeline_creation_time_secs_since_epoch(self) -> int:
"""Returns the pipeline creation time as seconds since epoch."""
self._check_context()
# Convert from milliseconds to seconds.
return self._execution.create_time_since_epoch // 1000
return self.execution.create_time_since_epoch // 1000

def get_orchestration_options(
self) -> orchestration_options.OrchestrationOptions:
Expand Down Expand Up @@ -1188,6 +1193,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
assert mlmd_execution_atomic_op_context is not None
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()
Expand Down
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -169,6 +169,7 @@ def __call__(self) -> List[task_lib.Task]:
):
successful_node_ids.add(node_id)
elif node_state.is_failure():
assert node_state.status is not None
failed_nodes_dict[node_id] = node_state.status

# Collect nodes that cannot be run because they have a failed ancestor.
Expand Down Expand Up @@ -545,6 +546,7 @@ def _generate_tasks_from_resolved_inputs(
# executions. Idempotency is guaranteed by external_id.
updated_external_artifacts = []
for input_and_params in resolved_info.input_and_params:
assert input_and_params.input_artifacts is not None
for artifacts in input_and_params.input_artifacts.values():
updated_external_artifacts.extend(
task_gen_utils.update_external_artifact_type(
Expand Down

0 comments on commit ad04c2d

Please sign in to comment.