From 1692daa393f130388bf9ff8ac70205ba2d04fa7c Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 12 Jun 2024 14:31:23 -0700 Subject: [PATCH] implement function get_upstream_artifacts_by_artifacts PiperOrigin-RevId: 642738967 --- .../mlmd_resolver/metadata_resolver.py | 273 +++++++++++++----- .../mlmd_resolver/metadata_resolver_test.py | 257 ++++++++++++++++- 2 files changed, 456 insertions(+), 74 deletions(-) diff --git a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py index 2aa52031d9e..553e8ec86fe 100644 --- a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py +++ b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver.py @@ -64,10 +64,104 @@ def __init__( self._store = store self._mlmd_connection_manager = mlmd_connection_manager - # TODO(b/302730333) Write a function get_upstream_artifacts_by_artifacts(), - # which is similar to get_downstream_artifacts_by_artifacts(). + def _get_external_upstream_or_downstream_artifacts( + self, + external_artifact_ids: List[str], + max_num_hops: int = _MAX_NUM_HOPS, + filter_query: str = '', + event_filter: Optional[Callable[[metadata_store_pb2.Event], bool]] = None, + downstream: bool = True, + ): + """Gets downstream or upstream artifacts from external artifact ids. + + Args: + external_artifact_ids: A list of external artifact ids. + max_num_hops: maximum number of hops performed for tracing. `max_num_hops` + cannot exceed 100 nor be negative. + filter_query: a query string filtering artifacts by their own attributes + or the attributes of immediate neighbors. Please refer to + go/mlmd-filter-query-guide for more detailed guidance. Note: if + `filter_query` is specified and `max_num_hops` is 0, it's equivalent to + getting filtered artifacts by artifact ids with `get_artifacts()`. + event_filter: an optional callable object for filtering events in the + paths towards the artifacts. Only an event with `event_filter(event)` + evaluated to True will be considered as valid and kept in the path. + downstream: If true, get downstream artifacts. Otherwise, get upstream + artifacts. + + Returns: + Mapping of artifact ids to a list of downstream or upstream artifacts. + + Raises: + ValueError: If mlmd_connection_manager is not initialized. + """ + if not self._mlmd_connection_manager: + raise ValueError( + 'mlmd_connection_manager is not initialized. There are external' + 'artifacts, so we need it to query the external MLMD instance.' + ) + + store_by_pipeline_asset: Dict[str, mlmd.MetadataStore] = {} + external_ids_by_pipeline_asset: Dict[str, List[str]] = ( + collections.defaultdict(list) + ) + for external_id in external_artifact_ids: + connection_config = ( + external_artifact_utils.get_external_connection_config(external_id) + ) + store = self._mlmd_connection_manager.get_mlmd_handle( + connection_config + ).store + pipeline_asset = ( + external_artifact_utils.get_pipeline_asset_from_external_id( + external_id + ) + ) + external_ids_by_pipeline_asset[pipeline_asset].append(external_id) + store_by_pipeline_asset[pipeline_asset] = store - # TODO(b/302730333) Write unit tests for the new functions. + result = {} + # Gets artifacts from each external store. + for pipeline_asset, external_ids in external_ids_by_pipeline_asset.items(): + store = store_by_pipeline_asset[pipeline_asset] + external_id_by_id = { + external_artifact_utils.get_id_from_external_id(e): e + for e in external_ids + } + artifacts_by_artifact_ids_fn = ( + self.get_downstream_artifacts_by_artifact_ids + if downstream + else self.get_upstream_artifacts_by_artifact_ids + ) + artifacts_and_types_by_artifact_id = artifacts_by_artifact_ids_fn( + list(external_id_by_id.keys()), + max_num_hops, + filter_query, + event_filter, + store, + ) + + pipeline_owner = pipeline_asset.split('/')[0] + pipeline_name = pipeline_asset.split('/')[1] + artifacts_by_external_id = {} + for ( + artifact_id, + artifacts_and_types, + ) in artifacts_and_types_by_artifact_id.items(): + external_id = external_id_by_id[artifact_id] + imported_artifacts_and_types = [] + for a, t in artifacts_and_types: + imported_artifact = external_artifact_utils.cold_import_artifacts( + t, [a], pipeline_owner, pipeline_name + )[0] + imported_artifacts_and_types.append( + (imported_artifact.mlmd_artifact, imported_artifact.artifact_type) + ) + artifacts_by_external_id[external_id] = imported_artifacts_and_types + + result.update(artifacts_by_external_id) + + return result def get_downstream_artifacts_by_artifacts( self, @@ -81,7 +175,7 @@ def get_downstream_artifacts_by_artifacts( ]: """Given a list of artifacts, get their provenance successor artifacts. - For each artifact matched by a given `artifact_id`, treat it as a starting + For each provided artifact, treat it as a starting artifact and get artifacts that are connected to them within `max_num_hops` via a path in the downstream direction like: artifact_i -> INPUT_event -> execution_j -> OUTPUT_event -> artifact_k. @@ -95,7 +189,7 @@ def get_downstream_artifacts_by_artifacts( Args: artifacts: a list of starting artifacts. At most 100 ids are supported. - Returns empty result if `artifact_ids` is empty. + Returns empty result if `artifacts` is empty. max_num_hops: maximum number of hops performed for downstream tracing. `max_num_hops` cannot exceed 100 nor be negative. filter_query: a query string filtering downstream artifacts by their own @@ -128,76 +222,24 @@ def get_downstream_artifacts_by_artifacts( internal_artifact_ids = [a.id for a in artifacts if not a.external_id] external_artifact_ids = [a.external_id for a in artifacts if a.external_id] + if internal_artifact_ids and external_artifact_ids: + raise ValueError( + 'Provided artifacts contain both internal and external artifacts. It' + ' is not supported.' + ) if not external_artifact_ids: return self.get_downstream_artifacts_by_artifact_ids( internal_artifact_ids, max_num_hops, filter_query, event_filter ) - if not self._mlmd_connection_manager: - raise ValueError( - 'mlmd_connection_manager is not initialized. There are external' - 'artifacts, so we need it to query the external MLMD instance.' - ) - - store_by_pipeline_asset: Dict[str, mlmd.MetadataStore] = {} - external_ids_by_pipeline_asset: Dict[str, List[str]] = ( - collections.defaultdict(list) + return self._get_external_upstream_or_downstream_artifacts( + external_artifact_ids, + max_num_hops, + filter_query, + event_filter, + downstream=True, ) - for external_id in external_artifact_ids: - connection_config = ( - external_artifact_utils.get_external_connection_config(external_id) - ) - store = self._mlmd_connection_manager.get_mlmd_handle( - connection_config - ).store - pipeline_asset = ( - external_artifact_utils.get_pipeline_asset_from_external_id( - external_id - ) - ) - external_ids_by_pipeline_asset[pipeline_asset].append(external_id) - store_by_pipeline_asset[pipeline_asset] = store - - result = {} - # Gets artifacts from each external store. - for pipeline_asset, external_ids in external_ids_by_pipeline_asset.items(): - store = store_by_pipeline_asset[pipeline_asset] - external_id_by_id = { - external_artifact_utils.get_id_from_external_id(e): e - for e in external_ids - } - artifacts_and_types_by_artifact_id = ( - self.get_downstream_artifacts_by_artifact_ids( - list(external_id_by_id.keys()), - max_num_hops, - filter_query, - event_filter, - store, - ) - ) - - pipeline_owner = pipeline_asset.split('/')[0] - pipeline_name = pipeline_asset.split('/')[1] - artifacts_by_external_id = {} - for ( - artifact_id, - artifacts_and_types, - ) in artifacts_and_types_by_artifact_id.items(): - external_id = external_id_by_id[artifact_id] - imported_artifacts_and_types = [] - for a, t in artifacts_and_types: - imported_artifact = external_artifact_utils.cold_import_artifacts( - t, [a], pipeline_owner, pipeline_name - )[0] - imported_artifacts_and_types.append( - (imported_artifact.mlmd_artifact, imported_artifact.artifact_type) - ) - artifacts_by_external_id[external_id] = imported_artifacts_and_types - - result.update(artifacts_by_external_id) - - return result def get_downstream_artifacts_by_artifact_ids( self, @@ -416,12 +458,91 @@ def get_downstream_artifacts_by_artifact_uri( for artifact_id, subgraph in artifacts_to_subgraph.items() } + def get_upstream_artifacts_by_artifacts( + self, + artifacts: List[metadata_store_pb2.Artifact], + max_num_hops: int = _MAX_NUM_HOPS, + filter_query: str = '', + event_filter: Optional[Callable[[metadata_store_pb2.Event], bool]] = None, + ) -> Dict[ + Union[str, int], + List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.ArtifactType]], + ]: + """Given a list of artifacts, get their provenance ancestor artifacts. + + For each provided artifact, treat it as a starting + artifact and get artifacts that are connected to them within `max_num_hops` + via a path in the upstream direction like: + artifact_i -> INPUT_event -> execution_j -> OUTPUT_event -> artifact_k. + + A hop is defined as a jump to the next node following the path of node + -> event -> next_node. + For example, in the lineage graph artifact_1 -> event -> execution_1 + -> event -> artifact_2: + artifact_2 is 2 hops away from artifact_1, and execution_1 is 1 hop away + from artifact_1. + + Args: + artifacts: a list of starting artifacts. At most 100 ids are supported. + Returns empty result if `artifacts` is empty. + max_num_hops: maximum number of hops performed for upstream tracing. + `max_num_hops` cannot exceed 100 nor be negative. + filter_query: a query string filtering upstream artifacts by their own + attributes or the attributes of immediate neighbors. Please refer to + go/mlmd-filter-query-guide for more detailed guidance. Note: if + `filter_query` is specified and `max_num_hops` is 0, it's equivalent + to getting filtered artifacts by artifact ids with `get_artifacts()`. + event_filter: an optional callable object for filtering events in the + paths towards the upstream artifacts. Only an event with + `event_filter(event)` evaluated to True will be considered as valid + and kept in the path. + + Returns: + Mapping of artifact ids to a list of upstream artifacts. + """ + if not artifacts: + return {} + + # Precondition check. + if len(artifacts) > _MAX_NUM_STARTING_NODES: + raise ValueError( + 'Number of artifacts is larger than supported value of %d.' + % _MAX_NUM_STARTING_NODES + ) + if max_num_hops > _MAX_NUM_HOPS or max_num_hops < 0: + raise ValueError( + 'Number of hops %d is larger than supported value of %d or is' + ' negative.' % (max_num_hops, _MAX_NUM_HOPS) + ) + + internal_artifact_ids = [a.id for a in artifacts if not a.external_id] + external_artifact_ids = [a.external_id for a in artifacts if a.external_id] + if internal_artifact_ids and external_artifact_ids: + raise ValueError( + 'Provided artifacts contain both internal and external artifacts. It' + ' is not supported.' + ) + + if not external_artifact_ids: + return self.get_upstream_artifacts_by_artifact_ids( + internal_artifact_ids, max_num_hops, filter_query, event_filter + ) + + return self._get_external_upstream_or_downstream_artifacts( + external_artifact_ids, + max_num_hops, + filter_query, + event_filter, + downstream=False, + ) + def get_upstream_artifacts_by_artifact_ids( self, artifact_ids: List[int], max_num_hops: int = _MAX_NUM_HOPS, filter_query: str = '', event_filter: Optional[Callable[[metadata_store_pb2.Event], bool]] = None, + store: Optional[mlmd.MetadataStore] = None, ) -> Dict[ int, List[Tuple[metadata_store_pb2.Artifact, metadata_store_pb2.ArtifactType]], @@ -454,6 +575,7 @@ def get_upstream_artifacts_by_artifact_ids( paths towards the upstream artifacts. Only an event with `event_filter(event)` evaluated to True will be considered as valid and kept in the path. + store: A metadata_store.MetadataStore instance. Returns: Mapping of artifact ids to a list of upstream artifacts. @@ -467,20 +589,25 @@ def get_upstream_artifacts_by_artifact_ids( 'Number of hops is larger than supported or is negative.' ) + if store is None: + store = self._store + if store is None: + raise ValueError('MetadataStore provided to MetadataResolver is None.') + artifact_ids_str = ','.join(str(id) for id in artifact_ids) # If `max_num_hops` is set to 0, we don't need the graph traversal. if max_num_hops == 0: if not filter_query: - artifacts = self._store.get_artifacts_by_id(artifact_ids) + artifacts = store.get_artifacts_by_id(artifact_ids) else: - artifacts = self._store.get_artifacts( + artifacts = store.get_artifacts( list_options=mlmd.ListOptions( filter_query=f'id IN ({artifact_ids_str}) AND ({filter_query})', limit=_MAX_NUM_STARTING_NODES, ) ) artifact_type_ids = [a.type_id for a in artifacts] - artifact_types = self._store.get_artifact_types_by_id(artifact_type_ids) + artifact_types = store.get_artifact_types_by_id(artifact_type_ids) artifact_type_by_id = {t.id: t for t in artifact_types} return { artifact.id: [(artifact, artifact_type_by_id[artifact.type_id])] @@ -499,7 +626,7 @@ def get_upstream_artifacts_by_artifact_ids( _EVENTS_FIELD_MASK_PATH, _ARTIFACT_TYPES_MASK_PATH, ] - lineage_graph = self._store.get_lineage_subgraph( + lineage_graph = store.get_lineage_subgraph( query_options=options, field_mask_paths=field_mask_paths, ) @@ -537,7 +664,7 @@ def get_upstream_artifacts_by_artifact_ids( ) artifact_ids_str = ','.join(str(id) for id in candidate_artifact_ids) # Send a call to metadata_store to get filtered upstream artifacts. - artifacts = self._store.get_artifacts( + artifacts = store.get_artifacts( list_options=mlmd.ListOptions( filter_query=f'id IN ({artifact_ids_str}) AND ({filter_query})' ) diff --git a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver_test.py b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver_test.py index a852f27ae50..9e55194e700 100644 --- a/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver_test.py +++ b/tfx/orchestration/portable/input_resolution/mlmd_resolver/metadata_resolver_test.py @@ -14,8 +14,11 @@ """Integration tests for metadata resolver.""" from typing import Dict, List from absl.testing import absltest +from tfx.orchestration import metadata +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.orchestration.portable.input_resolution.mlmd_resolver import metadata_resolver from tfx.orchestration.portable.input_resolution.mlmd_resolver import metadata_resolver_utils +from tfx.types import external_artifact_utils import ml_metadata as mlmd from ml_metadata.proto import metadata_store_pb2 @@ -152,7 +155,29 @@ def setUp(self): connection_config = metadata_store_pb2.ConnectionConfig() connection_config.fake_database.SetInParent() self.store = mlmd.MetadataStore(connection_config) - self.resolver = metadata_resolver.MetadataResolver(self.store) + + self._mlmd_connection_manager = mlmd_cm.MLMDConnectionManager( + primary_connection_config=connection_config + ) + self.enter_context(self._mlmd_connection_manager) + + connection_config = metadata_store_pb2.ConnectionConfig() + connection_config.fake_database.SetInParent() + mlmd_connection = metadata.Metadata(connection_config=connection_config) + self.enter_context(mlmd_connection) + self.ext_store = mlmd_connection.store + ext_connection_config = ( + external_artifact_utils.get_external_connection_config( + 'mlmd://prod:owner/ext-pipeline:artifact:1' + ) + ) + self._mlmd_connection_manager._mlmd_handle_by_config_id[ + self._mlmd_connection_manager._get_identifier(ext_connection_config) + ] = mlmd_connection + + self.resolver = metadata_resolver.MetadataResolver( + self.store, mlmd_connection_manager=self._mlmd_connection_manager + ) self.exp_type = create_artifact_type(self.store, 'Examples') self.example_gen_type = create_execution_type(self.store, 'ExampleGen') @@ -242,6 +267,164 @@ def setUp(self): contexts=[self.pipe_ctx, self.run3_ctx, self.evaluator_ctx], ) + # The following lines create artifacts in the external store. + self.ext_exp_type = create_artifact_type(self.ext_store, 'ext-Examples') + self.ext_example_gen_type = create_execution_type( + self.ext_store, 'ext-ExampleGen' + ) + self.ext_trainer_type = create_execution_type(self.ext_store, 'ext-Trainer') + self.ext_model_type = create_artifact_type(self.ext_store, 'ext-Model') + self.ext_evaluator_type = create_execution_type( + self.ext_store, 'ext-Evaluator' + ) + self.ext_evaluation_type = create_artifact_type( + self.ext_store, 'ext-Evaluation' + ) + self.ext_pipe_type = create_context_type(self.ext_store, 'pipeline') + self.ext_run_type = create_context_type(self.ext_store, 'pipeline_run') + self.ext_node_type = create_context_type(self.ext_store, 'node') + + self.ext_pipe_ctx = create_context( + self.ext_store, self.pipe_type.id, 'ext-pipeline' + ) + self.ext_run_ctx = create_context( + self.ext_store, self.run_type.id, 'ext-pipeline.run-01' + ) + self.ext_example_gen_ctx = create_context( + self.ext_store, self.node_type.id, 'ext-pipeline.ExampleGen' + ) + self.ext_trainer_ctx = create_context( + self.ext_store, self.node_type.id, 'ext-pipeline.Trainer' + ) + self.ext_evaluator_ctx = create_context( + self.ext_store, self.node_type.id, 'ext-pipeline.Evaluator' + ) + + self.ext_e1 = create_artifact( + self.ext_store, self.ext_exp_type.id, name='ext-Example-1' + ) + self.ext_m1 = create_artifact( + self.ext_store, self.ext_model_type.id, name='ext-Model-1' + ) + self.ext_ev1 = create_artifact( + self.ext_store, self.ext_evaluation_type.id, name='ext-Evaluation-1' + ) + + self.ext_expgen1 = create_execution( + self.ext_store, + self.ext_example_gen_type.id, + name='ExampleGen-1', + inputs={}, + outputs={'examples': [self.ext_e1]}, + contexts=[ + self.ext_pipe_ctx, + self.ext_run_ctx, + self.ext_example_gen_ctx, + ], + ) + self.ext_trainer1 = create_execution( + self.ext_store, + self.ext_trainer_type.id, + name='Trainer-1', + inputs={'examples': [self.ext_e1]}, + outputs={'model': [self.ext_m1]}, + contexts=[self.ext_pipe_ctx, self.ext_run_ctx, self.ext_trainer_ctx], + ) + self.ext_evaluator1 = create_execution( + self.ext_store, + self.ext_evaluator_type.id, + name='Evaluator-1', + inputs={'examples': [self.ext_e1], 'model': [self.ext_m1]}, + outputs={'evaluation': [self.ext_ev1]}, + contexts=[self.ext_pipe_ctx, self.ext_run_ctx, self.ext_evaluator_ctx], + ) + + def test_get_downstream_artifacts_by_artifacts(self): + # Test: get downstream artifacts by exp1, with max_num_hops = 0 + result_from_exp1 = self.resolver.get_downstream_artifacts_by_artifacts( + [self.e1], max_num_hops=0 + ) + self.assertLen(result_from_exp1, 1) + self.assertIn(self.e1.id, result_from_exp1) + self.assertCountEqual( + [result_from_exp1[self.e1.id][0][0].name], [self.e1.name] + ) + + # Test: get downstream artifacts by exp1, with max_num_hops = 20 + result_from_exp1 = self.resolver.get_downstream_artifacts_by_artifacts( + [self.e1], max_num_hops=20 + ) + self.assertLen(result_from_exp1, 1) + self.assertIn(self.e1.id, result_from_exp1) + self.assertCountEqual( + [(a.name, t.name) for a, t in result_from_exp1[self.e1.id]], + [ + (self.e1.name, self.exp_type.name), + (self.m1.name, self.model_type.name), + (self.ev1.name, self.evaluation_type.name), + ], + ) + + def test_get_downstream_artifacts_by_artifacts_external(self): + ext_e1_imported = external_artifact_utils.cold_import_artifacts( + self.ext_exp_type, + [self.ext_e1], + project_owner='owner', + project_name='ext-pipeline', + )[0] + + # Get downstream artifacts by external artifact ext_e1, max_num_hops = 0 + result_from_ext_e1 = self.resolver.get_downstream_artifacts_by_artifacts( + [ext_e1_imported.mlmd_artifact], max_num_hops=0 + ) + self.assertLen(result_from_ext_e1, 1) + self.assertIn( + 'mlmd://prod:owner/ext-pipeline:artifact:1', result_from_ext_e1 + ) + self.assertCountEqual( + [ + result_from_ext_e1['mlmd://prod:owner/ext-pipeline:artifact:1'][0][ + 0 + ].name + ], + [self.ext_e1.name], + ) + + # Get downstream artifacts by external artifact ext_e1, max_num_hops = 20 + result_from_ext_e1 = self.resolver.get_downstream_artifacts_by_artifacts( + [ext_e1_imported.mlmd_artifact], max_num_hops=20 + ) + self.assertLen(result_from_ext_e1, 1) + self.assertIn( + 'mlmd://prod:owner/ext-pipeline:artifact:1', result_from_ext_e1 + ) + self.assertCountEqual( + [ + (a.name, t.name) + for a, t in result_from_ext_e1[ + 'mlmd://prod:owner/ext-pipeline:artifact:1' + ] + ], + [ + (self.ext_e1.name, self.ext_exp_type.name), + (self.ext_m1.name, self.ext_model_type.name), + (self.ext_ev1.name, self.ext_evaluation_type.name), + ], + ) + + def test_get_downstream_artifacts_by_artifacts_mixed(self): + ext_e1_imported = external_artifact_utils.cold_import_artifacts( + self.ext_exp_type, + [self.ext_e1], + project_owner='owner', + project_name='ext-pipeline', + )[0] + + with self.assertRaises(ValueError): + self.resolver.get_downstream_artifacts_by_artifacts( + [ext_e1_imported.mlmd_artifact, self.e1], max_num_hops=0 + ) + def test_get_downstream_artifacts_by_artifact_ids(self): # Test: get downstream artifacts by example_1, with max_num_hops = 0 result_from_exp1 = self.resolver.get_downstream_artifacts_by_artifact_ids( @@ -624,6 +807,78 @@ def _is_input_event_or_valid_output_event( [(self.m1.name, self.model_type.name)], ) + def test_get_upstream_artifacts_by_artifacts(self): + # Test: get upstream artifacts by m1, with max_num_hops = 0 + result_from_m1 = self.resolver.get_upstream_artifacts_by_artifacts( + [self.m1], max_num_hops=0 + ) + self.assertLen(result_from_m1, 1) + self.assertIn(self.m1.id, result_from_m1) + self.assertCountEqual( + [result_from_m1[self.m1.id][0][0].name], [self.m1.name] + ) + + # Test: get upstream artifacts by m1, with max_num_hops = 20 + result_from_m1 = self.resolver.get_upstream_artifacts_by_artifacts( + [self.m1], max_num_hops=20 + ) + self.assertLen(result_from_m1, 1) + self.assertIn(self.m1.id, result_from_m1) + self.assertCountEqual( + [(a.name, t.name) for a, t in result_from_m1[self.m1.id]], + [ + (self.e1.name, self.exp_type.name), + (self.m1.name, self.model_type.name), + (self.e2.name, self.exp_type.name), + ], + ) + + def test_get_upstream_artifacts_by_artifacts_external(self): + ext_m1_imported = external_artifact_utils.cold_import_artifacts( + self.ext_model_type, + [self.ext_m1], + project_owner='owner', + project_name='ext-pipeline', + )[0] + + # Test: get upstream artifacts by external artifact ext_m1, max_num_hops=0 + result_from_ext_m1 = self.resolver.get_upstream_artifacts_by_artifacts( + [ext_m1_imported.mlmd_artifact], max_num_hops=0 + ) + self.assertLen(result_from_ext_m1, 1) + self.assertIn( + 'mlmd://prod:owner/ext-pipeline:artifact:2', result_from_ext_m1 + ) + self.assertCountEqual( + [ + result_from_ext_m1['mlmd://prod:owner/ext-pipeline:artifact:2'][0][ + 0 + ].name + ], + [self.ext_m1.name], + ) + + # Test: get upstream artifacts by external artifact ext_m1, max_num_hops=20 + result_from_ext_m1 = self.resolver.get_upstream_artifacts_by_artifacts( + [ext_m1_imported.mlmd_artifact], max_num_hops=20 + ) + self.assertLen(result_from_ext_m1, 1) + self.assertIn( + 'mlmd://prod:owner/ext-pipeline:artifact:2', result_from_ext_m1 + ) + self.assertCountEqual( + [ + (a.name, t.name) + for a, t in result_from_ext_m1[ + 'mlmd://prod:owner/ext-pipeline:artifact:2' + ] + ], + [ + (self.ext_e1.name, self.ext_exp_type.name), + (self.ext_m1.name, self.ext_model_type.name), + ], + ) + def test_get_upstream_artifacts_by_artifact_ids(self): # Test: get upstream artifacts by model_1, with max_num_hops = 0 result_from_m1 = self.resolver.get_upstream_artifacts_by_artifact_ids(