diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py
index 8c07f609777..8ca6ce91fea 100644
--- a/tfx/orchestration/experimental/core/pipeline_ops.py
+++ b/tfx/orchestration/experimental/core/pipeline_ops.py
@@ -37,6 +37,7 @@
 from tfx.orchestration.experimental.core import constants
 from tfx.orchestration.experimental.core import env
 from tfx.orchestration.experimental.core import event_observer
+from tfx.orchestration.experimental.core import garbage_collection
 from tfx.orchestration.experimental.core import mlmd_state
 from tfx.orchestration.experimental.core import pipeline_state as pstate
 from tfx.orchestration.experimental.core import service_jobs
@@ -2203,6 +2204,7 @@ def publish_intermediate_artifact(
     mlmd_handle: metadata.Metadata,
     execution_id: int,
     output_key: str,
+    node_uid: task_lib.NodeUid,
     properties: Optional[Dict[str, metadata_store_pb2.Value]],
     custom_properties: Optional[Dict[str, metadata_store_pb2.Value]],
     external_uri: Optional[str] = None,
@@ -2214,6 +2216,7 @@ def publish_intermediate_artifact(
     mlmd_handle: A handle to the MLMD database.
     execution_id: The ID of the execution which generates the artifact.
     output_key: The output key of the artifact.
+    node_uid: UID of the node generating the artifact; if set, GC runs for it.
     properties: Properties of the artifact.
     custom_properties: Custom properties of the artifact.
     external_uri: The external URI provided by the user. Exactly one of
@@ -2318,5 +2321,13 @@ def publish_intermediate_artifact(
   except mlmd_errors.StatusError as e:
     raise status_lib.StatusNotOkError(code=e.error_code, message=str(e))
 
+  if node_uid:  # Skip GC when node_uid is falsy (annotation is non-Optional).
+    _, node = pstate.get_pipeline_and_node(
+        mlmd_handle, node_uid, node_uid.pipeline_uid.pipeline_run_id
+    )
+    garbage_collection.run_garbage_collection_for_node(
+        mlmd_handle, node_uid, node
+    )
+
+  logging.info('Published intermediate artifact: %s', intermediate_artifact)
   return intermediate_artifact
diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 32139c5e629..688f3fe1632 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -1669,3 +1669,36 @@ def _get_sub_pipeline_ids_from_pipeline_info(
   sub_pipeline_ids = pipeline_info.parent_ids[1:]
   sub_pipeline_ids.append(pipeline_info.id)
   return sub_pipeline_ids
+
+
+def get_pipeline_and_node(
+    mlmd_handle: metadata.Metadata,
+    node_uid: task_lib.NodeUid,
+    pipeline_run_id: str,
+) -> tuple[pipeline_pb2.Pipeline, node_proto_view.PipelineNodeProtoView]:
+  """Returns the loaded pipeline and the single node identified by node_uid."""
+  with PipelineState.load(mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
+    if (  # Run ids must match unless both are unset (falsy).
+        pipeline_run_id or pipeline_state.pipeline_run_id
+    ) and pipeline_run_id != pipeline_state.pipeline_run_id:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.NOT_FOUND,
+          message=(
+              'unable to find an active pipeline run for pipeline_run_id: '
+              f'{pipeline_run_id}'
+          ),
+      )
+    nodes = get_all_nodes(pipeline_state.pipeline)
+    filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id]
+    if len(filtered_nodes) != 1:  # Zero matches or duplicate node ids.
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.NOT_FOUND,
+          message=f'unable to find node: {node_uid}',
+      )
+    node = filtered_nodes[0]
+    if not isinstance(node, node_proto_view.PipelineNodeProtoView):
+      raise ValueError(
+          f'Unexpected type for node {node.node_info.id}. Only '
+          'pipeline nodes are supported for external executions.'
+      )
+    return (pipeline_state.pipeline, node)