Skip to content

Commit

Permalink
Fix bug where pipeline_as_node node context is associated with the ca…
Browse files Browse the repository at this point in the history
…ched executions.

PiperOrigin-RevId: 634482397
  • Loading branch information
kmonte authored and tfx-copybara committed May 16, 2024
1 parent c4e79f0 commit 4746da8
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions tfx/orchestration/portable/partial_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,24 +649,61 @@ def _get_base_pipeline_run_context(

def _get_node_context(
    self, node: node_proto_view.NodeProtoView
) -> list[metadata_store_pb2.Context]:
  """Returns node contexts for a node.

  For subpipelines, both the end node context and the subpipeline-as-node
  context are returned, because downstream consumers may query either one.

  Args:
    node: The node to get the contexts for.

  Returns:
    The node contexts for the node.

  Raises:
    LookupError: If a node context is not found in MLMD.
    ValueError: If fetching contexts for a subpipeline that has no parent
      pipeline ids.
  """
  contexts = []
  node_id = node.node_info.id
  if isinstance(node, node_proto_view.ComposablePipelineProtoView):
    # Use the end node context when reusing a subpipeline. We do this
    # because nodes dependent on a subpipeline use the subpipeline's end
    # node to get their artifacts from, so we reuse those artifacts.
    # TODO: b/340911977 - Once we only have subpipeline as node for input
    # context queries, we should remove the end node context.
    context_name = compiler_utils.end_node_context_name_from_subpipeline_id(
        node_id
    )
    # Subpipelines are also considered a node in the parent pipeline, so we
    # also need to add the pipeline-as-node context.
    parent_pipeline_ids = node.raw_proto().pipeline_info.parent_ids
    if not parent_pipeline_ids:
      raise ValueError(
          f'Subpipeline {node_id} does not have any parent pipelines.'
      )
    # The immediate (innermost) parent owns the pipeline-as-node context.
    parent_pipeline_name = parent_pipeline_ids[-1]
    pipeline_as_node_name = compiler_utils.node_context_name(
        parent_pipeline_name, node_id
    )
    pipeline_as_node_context = self._node_context_by_name.get(
        pipeline_as_node_name
    )
    if pipeline_as_node_context is None:
      raise LookupError(
          f'node context {pipeline_as_node_name} not found in MLMD.'
      )
    contexts.append(pipeline_as_node_context)
  else:
    context_name = compiler_utils.node_context_name(
        self._pipeline_name, node_id
    )
  node_context = self._node_context_by_name.get(context_name)
  if node_context is None:
    raise LookupError(f'node context {context_name} not found in MLMD.')
  contexts.append(node_context)
  return contexts

def _get_successful_executions(
self, node: node_proto_view.NodeProtoView
Expand All @@ -682,7 +719,7 @@ def _get_successful_executions(
Raises:
LookupError: If no successful Execution was found.
"""
node_context = self._get_node_context(node)
node_contexts = self._get_node_context(node)
node_id = node.node_info.id
if not self._base_run_context:
raise LookupError(
Expand All @@ -693,10 +730,9 @@ def _get_successful_executions(

all_associated_executions = (
execution_lib.get_executions_associated_with_all_contexts(
self._mlmd, contexts=[node_context, self._base_run_context]
self._mlmd, contexts=[self._base_run_context] + node_contexts
)
)

cache_only_succesful_executions = (
not node.execution_options.node_success_optional
)
Expand Down Expand Up @@ -741,15 +777,15 @@ def _cache_and_publish(
return

# Check if there are any previous attempts to cache and publish.
node_context = self._get_node_context(node)
node_contexts = self._get_node_context(node)
cached_execution_contexts = [
self._pipeline_context,
node_context,
self._new_pipeline_run_context,
]
] + node_contexts
prev_cache_executions = (
execution_lib.get_executions_associated_with_all_contexts(
self._mlmd, contexts=[node_context, self._new_pipeline_run_context]
self._mlmd,
contexts=[self._new_pipeline_run_context] + node_contexts,
)
)
if not prev_cache_executions:
Expand Down

0 comments on commit 4746da8

Please sign in to comment.