Skip to content

Commit

Permalink
no-op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 640684440
  • Loading branch information
XinranTang authored and tfx-copybara committed Jun 5, 2024
1 parent baab834 commit 07e82ee
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import env
from tfx.orchestration.experimental.core import event_observer
from tfx.orchestration.experimental.core import garbage_collection
from tfx.orchestration.experimental.core import mlmd_state
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import service_jobs
Expand Down Expand Up @@ -2203,6 +2204,7 @@ def publish_intermediate_artifact(
mlmd_handle: metadata.Metadata,
execution_id: int,
output_key: str,
node_uid: task_lib.NodeUid,
properties: Optional[Dict[str, metadata_store_pb2.Value]],
custom_properties: Optional[Dict[str, metadata_store_pb2.Value]],
external_uri: Optional[str] = None,
Expand All @@ -2214,6 +2216,7 @@ def publish_intermediate_artifact(
mlmd_handle: A handle to the MLMD database.
execution_id: The ID of the execution which generates the artifact.
output_key: The output key of the artifact.
node_uid: The node UID of the node which generates the artifact.
properties: Properties of the artifact.
custom_properties: Custom properties of the artifact.
external_uri: The external URI provided by the user. Exactly one of
Expand Down Expand Up @@ -2318,5 +2321,13 @@ def publish_intermediate_artifact(
except mlmd_errors.StatusError as e:
raise status_lib.StatusNotOkError(code=e.error_code, message=str(e))

if node_uid:
_, node = pstate.get_pipeline_and_node(
mlmd_handle, node_uid, node_uid.pipeline_uid.pipeline_run_id
)
garbage_collection.run_garbage_collection_for_node(
mlmd_handle, node_uid, node
)

logging.info('Published intermediate artifact: %s', intermediate_artifact)
return intermediate_artifact
33 changes: 33 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,3 +1669,36 @@ def _get_sub_pipeline_ids_from_pipeline_info(
sub_pipeline_ids = pipeline_info.parent_ids[1:]
sub_pipeline_ids.append(pipeline_info.id)
return sub_pipeline_ids


def get_pipeline_and_node(
mlmd_handle: metadata.Metadata,
node_uid: task_lib.NodeUid,
pipeline_run_id: str,
) -> tuple[pipeline_pb2.Pipeline, node_proto_view.PipelineNodeProtoView]:
"""Gets the pipeline and node for the node_uid."""
with PipelineState.load(mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
if (
pipeline_run_id or pipeline_state.pipeline_run_id
) and pipeline_run_id != pipeline_state.pipeline_run_id:
raise status_lib.StatusNotOkError(
code=status_lib.Code.NOT_FOUND,
message=(
'unable to find an active pipeline run for pipeline_run_id: '
f'{pipeline_run_id}'
),
)
nodes = get_all_nodes(pipeline_state.pipeline)
filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id]
if len(filtered_nodes) != 1:
raise status_lib.StatusNotOkError(
code=status_lib.Code.NOT_FOUND,
message=f'unable to find node: {node_uid}',
)
node = filtered_nodes[0]
if not isinstance(node, node_proto_view.PipelineNodeProtoView):
raise ValueError(
f'Unexpected type for node {node.node_info.id}. Only '
'pipeline nodes are supported for external executions.'
)
return (pipeline_state.pipeline, node)

0 comments on commit 07e82ee

Please sign in to comment.