From 3b1b76244dfc6121e866f9d8803f1491ee6f5cec Mon Sep 17 00:00:00 2001
From: Shirshanka Das
Date: Sat, 19 Oct 2024 14:53:28 -0700
Subject: [PATCH] feat(sdk):platform-resource - complex queries (#11675)

---
 .../platformresource/platform_resource.py     | 193 ++++++------
 .../src/datahub/utilities/openapi_utils.py    |  69 +++++
 .../src/datahub/utilities/search_utils.py     | 285 ++++++++++++++++++
 .../test_platform_resource.py                 |  15 +
 .../tests/unit/utilities/test_search_utils.py |  71 +++++
 .../test_platform_resource.py                 |  78 ++++-
 6 files changed, 617 insertions(+), 94 deletions(-)
 create mode 100644 metadata-ingestion/src/datahub/utilities/openapi_utils.py
 create mode 100644 metadata-ingestion/src/datahub/utilities/search_utils.py
 create mode 100644 metadata-ingestion/tests/unit/utilities/test_search_utils.py

diff --git a/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py b/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py
index 349b0ff11d84f..0f7b10a067053 100644
--- a/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py
+++ b/metadata-ingestion/src/datahub/api/entities/platformresource/platform_resource.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union, cast
 
 from avrogen.dict_wrapper import DictWrapper
 from pydantic import BaseModel
@@ -14,7 +14,14 @@
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.mcp_builder import DatahubKey
 from datahub.ingestion.graph.client import DataHubGraph
-from datahub.metadata.urns import PlatformResourceUrn
+from datahub.metadata.urns import DataPlatformUrn, PlatformResourceUrn, Urn
+from datahub.utilities.openapi_utils import OpenAPIGraphClient
+from datahub.utilities.search_utils import (
+    ElasticDocumentQuery,
+    ElasticsearchQueryBuilder,
+    LogicalOperator,
+    SearchField,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -69,71 +76,75 @@ def to_resource_info(self) -> models.PlatformResourceInfoClass:
         )
 
 
-class OpenAPIGraphClient:
+class DataPlatformInstanceUrn:
+    """
+    A simple implementation of a URN class for DataPlatformInstance.
+    Since this is not present in the URN registry, we need to implement it here.
+    """
 
-    ENTITY_KEY_ASPECT_MAP = {
-        aspect_type.ASPECT_INFO.get("keyForEntity"): name
-        for name, aspect_type in models.ASPECT_NAME_MAP.items()
-        if aspect_type.ASPECT_INFO.get("keyForEntity")
-    }
+    @staticmethod
+    def create_from_id(platform_instance_urn: str) -> Urn:
+        if platform_instance_urn.startswith("urn:li:platformInstance:"):
+            string_urn = platform_instance_urn
+        else:
+            string_urn = f"urn:li:platformInstance:{platform_instance_urn}"
+        return Urn.from_string(string_urn)
 
-    def __init__(self, graph: DataHubGraph):
-        self.graph = graph
-        self.openapi_base = graph._gms_server.rstrip("/") + "/openapi/v3"
 
-    def scroll_urns_by_filter(
-        self,
-        entity_type: str,
-        extra_or_filters: List[Dict[str, str]],
-        extra_and_filters: List[Dict[str, str]] = [],
-    ) -> Iterable[str]:
-        """
-        Scroll through all urns that match the given filters
-        """
+class UrnSearchField(SearchField):
+    """
+    A search field that supports URN values.
+    TODO: Move this to search_utils after we make this more generic.
+    """
 
-        key_aspect = self.ENTITY_KEY_ASPECT_MAP.get(entity_type)
-        assert key_aspect, f"No key aspect found for entity type {entity_type}"
-        if extra_or_filters and extra_and_filters:
-            raise ValueError(
-                "Only one of extra_or_filters and extra_and_filters should be provided"
-            )
+    def __init__(self, field_name: str, urn_value_extractor: Callable[[str], Urn]):
+        self.urn_value_extractor = urn_value_extractor
+        super().__init__(field_name)
 
-        count = 1000
-        query = (
-            " OR ".join(
-                [
-                    f"{filter['field']}:\"{filter['value']}\""
-                    for filter in extra_or_filters
-                ]
-            )
-            if extra_or_filters
-            else " AND ".join(
-                [
-                    f"{filter['field']}:\"{filter['value']}\""
-                    for filter in extra_and_filters
-                ]
-            )
+    def get_search_value(self, value: str) -> str:
+        return str(self.urn_value_extractor(value))
+
+
+class PlatformResourceSearchField(SearchField):
+    def __init__(self, field_name: str):
+        super().__init__(field_name)
+
+    @classmethod
+    def from_search_field(
+        cls, search_field: SearchField
+    ) -> "PlatformResourceSearchField":
+        # pretends to be a class method, but just returns the input
+        return search_field  # type: ignore
+
+
+class PlatformResourceSearchFields:
+    PRIMARY_KEY = PlatformResourceSearchField("primaryKey")
+    RESOURCE_TYPE = PlatformResourceSearchField("resourceType")
+    SECONDARY_KEYS = PlatformResourceSearchField("secondaryKeys")
+    PLATFORM = PlatformResourceSearchField.from_search_field(
+        UrnSearchField(
+            field_name="platform.keyword",
+            urn_value_extractor=DataPlatformUrn.create_from_id,
         )
-        scroll_id = None
-        while True:
-            response = self.graph._get_generic(
-                self.openapi_base + f"/entity/{entity_type.lower()}",
-                params={
-                    "systemMetadata": "false",
-                    "includeSoftDelete": "false",
-                    "skipCache": "false",
-                    "aspects": [key_aspect],
-                    "scrollId": scroll_id,
-                    "count": count,
-                    "query": query,
-                },
-            )
-            entities = response.get("entities", [])
-            scroll_id = response.get("scrollId")
-            for entity in entities:
-                yield entity["urn"]
-            if not scroll_id:
-                break
+    )
+    PLATFORM_INSTANCE = PlatformResourceSearchField.from_search_field(
+        UrnSearchField(
+            field_name="platformInstance.keyword",
+            urn_value_extractor=DataPlatformInstanceUrn.create_from_id,
+        )
+    )
+
+
+class ElasticPlatformResourceQuery(ElasticDocumentQuery[PlatformResourceSearchField]):
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def create_from(
+        cls: Type["ElasticPlatformResourceQuery"],
+        *args: Tuple[Union[str, PlatformResourceSearchField], str],
+    ) -> "ElasticPlatformResourceQuery":
+        return cast(ElasticPlatformResourceQuery, super().create_from(*args))
 
 
 class PlatformResource(BaseModel):
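[Editor's note, not part of the patch: a minimal sketch of what the typed fields above buy you. `ElasticPlatformResourceQuery` builds the query string entirely client-side, and URN-backed fields such as `PLATFORM` normalize a bare id into a URN before escaping. The field values here are made up; the expected output is derived from the builder's escaping rules shown later in this patch.]

```python
from datahub.api.entities.platformresource.platform_resource import (
    ElasticPlatformResourceQuery,
    PlatformResourceSearchFields,
)
from datahub.utilities.search_utils import LogicalOperator

# "snowflake" is normalized to urn:li:dataPlatform:snowflake by the
# UrnSearchField behind PLATFORM, then the ":" characters are escaped.
query = (
    ElasticPlatformResourceQuery.create_from()
    .group(LogicalOperator.AND)
    .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, "tagMapping")
    .add_field_match(PlatformResourceSearchFields.PLATFORM, "snowflake")
    .end()
)
assert query.build() == (
    '(resourceType:"tagMapping" AND '
    'platform.keyword:"urn\\:li\\:dataPlatform\\:snowflake")'
)
```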
@@ -147,6 +158,12 @@ def remove(
         cls,
         key: PlatformResourceKey,
     ) -> "PlatformResource":
+        """
+        Creates a PlatformResource object with the removed status set to True.
+        Removed PlatformResource objects are used to soft-delete resources from
+        the graph.
+        To hard-delete a resource, use the delete method.
+        """
         return cls(
             id=key.id,
             removed=True,
@@ -240,28 +257,38 @@ def from_datahub(
 
     @staticmethod
     def search_by_key(
-        graph_client: DataHubGraph, key: str, primary: bool = True
+        graph_client: DataHubGraph,
+        key: str,
+        primary: bool = True,
+        is_exact: bool = True,
     ) -> Iterable["PlatformResource"]:
-        extra_or_filters = []
-        extra_or_filters.append(
-            {
-                "field": "primaryKey",
-                "condition": "EQUAL",
-                "value": key,
-            }
+        """
+        Searches for PlatformResource entities by primary or secondary key.
+
+        :param graph_client: DataHubGraph client
+        :param key: The key to search for
+        :param primary: Whether to search primary keys only, or to expand the
+            search to secondary keys (default: True)
+        :param is_exact: Whether to search for an exact match (default: True)
+        :return: An iterable of PlatformResource objects
+        """
+
+        elastic_platform_resource_group = (
+            ElasticPlatformResourceQuery.create_from()
+            .group(LogicalOperator.OR)
+            .add_field_match(
+                PlatformResourceSearchFields.PRIMARY_KEY, key, is_exact=is_exact
+            )
         )
         if not primary:  # we expand the search to secondary keys
-            extra_or_filters.append(
-                {
-                    "field": "secondaryKeys",
-                    "condition": "EQUAL",
-                    "value": key,
-                }
+            elastic_platform_resource_group.add_field_match(
+                PlatformResourceSearchFields.SECONDARY_KEYS, key, is_exact=is_exact
             )
+        query = elastic_platform_resource_group.end()
         openapi_client = OpenAPIGraphClient(graph_client)
         for urn in openapi_client.scroll_urns_by_filter(
             entity_type="platformResource",
-            extra_or_filters=extra_or_filters,
+            query=query,
         ):
             platform_resource = PlatformResource.from_datahub(graph_client, urn)
             if platform_resource:
@@ -273,18 +300,16 @@ def delete(self, graph_client: DataHubGraph, hard: bool = True) -> None:
     @staticmethod
     def search_by_filters(
         graph_client: DataHubGraph,
-        and_filters: List[Dict[str, str]] = [],
-        or_filters: List[Dict[str, str]] = [],
+        query: Union[
+            ElasticPlatformResourceQuery,
+            ElasticDocumentQuery,
+            ElasticsearchQueryBuilder,
+        ],
     ) -> Iterable["PlatformResource"]:
-        if and_filters and or_filters:
-            raise ValueError(
-                "Only one of and_filters and or_filters should be provided"
-            )
         openapi_client = OpenAPIGraphClient(graph_client)
         for urn in openapi_client.scroll_urns_by_filter(
             entity_type="platformResource",
-            extra_or_filters=or_filters if or_filters else [],
-            extra_and_filters=and_filters if and_filters else [],
+            query=query,
         ):
             platform_resource = PlatformResource.from_datahub(graph_client, urn)
             if platform_resource:
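[Editor's note, not part of the patch: a sketch of the new search surface end to end. `search_by_key` builds the OR group shown above internally, while `search_by_filters` accepts any pre-built query. The server URL is a placeholder, and `DatahubClientConfig` is assumed to be importable from the graph client module as in other DataHub examples.]

```python
from datahub.api.entities.platformresource.platform_resource import (
    ElasticPlatformResourceQuery,
    PlatformResource,
    PlatformResourceSearchFields,
)
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))  # placeholder

# Primary-or-secondary key lookup, tolerant of analyzed (non-exact) matches.
for resource in PlatformResource.search_by_key(
    graph, key="my_key", primary=False, is_exact=False
):
    print(resource.id)

# Arbitrary filter: every resource of a given (hypothetical) type.
query = ElasticPlatformResourceQuery.create_from(
    (PlatformResourceSearchFields.RESOURCE_TYPE, "tagMapping")
)
for resource in PlatformResource.search_by_filters(graph, query=query):
    print(resource.id)
```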
diff --git a/metadata-ingestion/src/datahub/utilities/openapi_utils.py b/metadata-ingestion/src/datahub/utilities/openapi_utils.py
new file mode 100644
index 0000000000000..e704ff7f84cbb
--- /dev/null
+++ b/metadata-ingestion/src/datahub/utilities/openapi_utils.py
@@ -0,0 +1,69 @@
+import logging
+from typing import Iterable, Union
+
+import datahub.metadata.schema_classes as models
+from datahub.ingestion.graph.client import DataHubGraph
+from datahub.utilities.search_utils import (
+    ElasticDocumentQuery,
+    ElasticsearchQueryBuilder,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAPIGraphClient:
+    """
+    An experimental client for the DataHubGraph that uses the OpenAPI endpoints
+    to query entities and aspects.
+    It does not support all features of the DataHubGraph, and its API is
+    subject to change.
+
+    DO NOT USE THIS UNLESS YOU KNOW WHAT YOU ARE DOING.
+    """
+
+    ENTITY_KEY_ASPECT_MAP = {
+        aspect_type.ASPECT_INFO.get("keyForEntity"): name
+        for name, aspect_type in models.ASPECT_NAME_MAP.items()
+        if aspect_type.ASPECT_INFO.get("keyForEntity")
+    }
+
+    def __init__(self, graph: DataHubGraph):
+        self.graph = graph
+        self.openapi_base = graph._gms_server.rstrip("/") + "/openapi/v3"
+
+    def scroll_urns_by_filter(
+        self,
+        entity_type: str,
+        query: Union[ElasticDocumentQuery, ElasticsearchQueryBuilder],
+    ) -> Iterable[str]:
+        """
+        Scroll through all urns that match the given filters.
+
+        """
+
+        key_aspect = self.ENTITY_KEY_ASPECT_MAP.get(entity_type)
+        assert key_aspect, f"No key aspect found for entity type {entity_type}"
+
+        count = 1000
+        string_query = query.build()
+        scroll_id = None
+        logger.debug(f"Scrolling with query: {string_query}")
+        while True:
+            response = self.graph._get_generic(
+                self.openapi_base + f"/entity/{entity_type.lower()}",
+                params={
+                    "systemMetadata": "false",
+                    "includeSoftDelete": "false",
+                    "skipCache": "false",
+                    "aspects": [key_aspect],
+                    "scrollId": scroll_id,
+                    "count": count,
+                    "query": string_query,
+                },
+            )
+            entities = response.get("entities", [])
+            scroll_id = response.get("scrollId")
+            for entity in entities:
+                yield entity["urn"]
+            if not scroll_id:
+                break
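[Editor's note, not part of the patch: an illustrative use of the scroll helper above. It pages through results `count` at a time, following `scrollId` until the server stops returning one, so the generator issues one HTTP request per page of 1000 URNs. The endpoint URL is a placeholder.]

```python
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.utilities.openapi_utils import OpenAPIGraphClient
from datahub.utilities.search_utils import ElasticDocumentQuery

graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))  # placeholder
client = OpenAPIGraphClient(graph)

# Lazily yields URNs of all platformResource entities of one type.
urns = client.scroll_urns_by_filter(
    entity_type="platformResource",
    query=ElasticDocumentQuery.create_from(("resourceType", "tagMapping")),
)
for urn in urns:
    print(urn)
```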
diff --git a/metadata-ingestion/src/datahub/utilities/search_utils.py b/metadata-ingestion/src/datahub/utilities/search_utils.py
new file mode 100644
index 0000000000000..0bd88addd8660
--- /dev/null
+++ b/metadata-ingestion/src/datahub/utilities/search_utils.py
@@ -0,0 +1,285 @@
+import logging
+import re
+from enum import Enum
+from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union
+
+logger = logging.getLogger(__name__)
+
+
+class LogicalOperator(Enum):
+    AND = "AND"
+    OR = "OR"
+
+
+class SearchField:
+    def __init__(self, field_name: str):
+        self.field_name = field_name
+
+    def get_search_value(self, value: str) -> str:
+        return value
+
+    def __str__(self) -> str:
+        return self.field_name
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    @classmethod
+    def from_string_field(cls, field_name: str) -> "SearchField":
+        return cls(field_name)
+
+
+class QueryNode:
+    def __init__(self, operator: Optional[LogicalOperator] = None):
+        self.operator = operator
+        self.children: List[Union[QueryNode, str]] = []
+
+    def add_child(self, child: Union["QueryNode", str]) -> None:
+        self.children.append(child)
+
+    def build(self) -> str:
+        if not self.children:
+            return ""
+
+        if self.operator is None:
+            return (
+                self.children[0]
+                if isinstance(self.children[0], str)
+                else self.children[0].build()
+            )
+
+        child_queries = []
+        for child in self.children:
+            if isinstance(child, str):
+                child_queries.append(child)
+            else:
+                child_queries.append(child.build())
+
+        joined_queries = f" {self.operator.value} ".join(child_queries)
+        return f"({joined_queries})" if len(child_queries) > 1 else joined_queries
+
+
+class ElasticsearchQueryBuilder:
+    SPECIAL_CHARACTERS = r'+-=&|><!(){}[]^"~*?:\/'
+
+    def __init__(self, operator: LogicalOperator = LogicalOperator.AND) -> None:
+        self.root = QueryNode(operator=operator)
+
+    @classmethod
+    def escape_special_characters(cls, value: str) -> str:
+        """
+        Escape special characters in the search term.
+        """
+        return re.sub(f"([{re.escape(cls.SPECIAL_CHARACTERS)}])", r"\\\1", value)
+
+    def _create_term(
+        self, field: SearchField, value: str, is_exact: bool = False
+    ) -> str:
+        escaped_value = self.escape_special_characters(field.get_search_value(value))
+        field_name: str = field.field_name
+        if is_exact:
+            return f'{field_name}:"{escaped_value}"'
+        return f"{field_name}:{escaped_value}"
+
+    def add_field_match(
+        self, field: SearchField, value: str, is_exact: bool = True
+    ) -> "ElasticsearchQueryBuilder":
+        term = self._create_term(field, value, is_exact)
+        self.root.add_child(term)
+        return self
+
+    def add_field_not_match(
+        self, field: SearchField, value: str, is_exact: bool = True
+    ) -> "ElasticsearchQueryBuilder":
+        term = f"-{self._create_term(field, value, is_exact)}"
+        self.root.add_child(term)
+        return self
+
+    def add_range(
+        self,
+        field: str,
+        min_value: Optional[str] = None,
+        max_value: Optional[str] = None,
+        include_min: bool = True,
+        include_max: bool = True,
+    ) -> "ElasticsearchQueryBuilder":
+        min_bracket = "[" if include_min else "{"
+        max_bracket = "]" if include_max else "}"
+        min_val = min_value if min_value is not None else "*"
+        max_val = max_value if max_value is not None else "*"
+        range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}"
+        self.root.add_child(range_query)
+        return self
+
+    def add_wildcard(self, field: str, pattern: str) -> "ElasticsearchQueryBuilder":
+        wildcard_query = f"{field}:{pattern}"
+        self.root.add_child(wildcard_query)
+        return self
+
+    def add_fuzzy(
+        self, field: str, value: str, fuzziness: int = 2
+    ) -> "ElasticsearchQueryBuilder":
+        fuzzy_query = f"{field}:{value}~{fuzziness}"
+        self.root.add_child(fuzzy_query)
+        return self
+
+    def add_boost(
+        self, field: str, value: str, boost: float
+    ) -> "ElasticsearchQueryBuilder":
+        boosted_query = f"{field}:{value}^{boost}"
+        self.root.add_child(boosted_query)
+        return self
+
+    def group(self, operator: LogicalOperator) -> "QueryGroup":
+        return QueryGroup(self, operator)
+
+    def build(self) -> str:
+        return self.root.build()
+
+
+class QueryGroup:
+    def __init__(self, parent: ElasticsearchQueryBuilder, operator: LogicalOperator):
+        self.parent = parent
+        self.node = QueryNode(operator)
+        self.parent.root.add_child(self.node)
+
+    def add_field_match(
+        self, field: Union[str, SearchField], value: str, is_exact: bool = True
+    ) -> "QueryGroup":
+        if isinstance(field, str):
+            field = SearchField.from_string_field(field)
+        term = self.parent._create_term(field, value, is_exact)
+        self.node.add_child(term)
+        return self
+
+    def add_field_not_match(
+        self, field: Union[str, SearchField], value: str, is_exact: bool = True
+    ) -> "QueryGroup":
+        if isinstance(field, str):
+            field = SearchField.from_string_field(field)
+        term = f"-{self.parent._create_term(field, value, is_exact)}"
+        self.node.add_child(term)
+        return self
+
+    def add_range(
+        self,
+        field: str,
+        min_value: Optional[str] = None,
+        max_value: Optional[str] = None,
+        include_min: bool = True,
+        include_max: bool = True,
+    ) -> "QueryGroup":
+        min_bracket = "[" if include_min else "{"
+        max_bracket = "]" if include_max else "}"
+        min_val = min_value if min_value is not None else "*"
+        max_val = max_value if max_value is not None else "*"
+        range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}"
+        self.node.add_child(range_query)
+        return self
+
+    def add_wildcard(self, field: str, pattern: str) -> "QueryGroup":
+        wildcard_query = f"{field}:{pattern}"
+        self.node.add_child(wildcard_query)
+        return self
+
+    def add_fuzzy(self, field: str, value: str, fuzziness: int = 2) -> "QueryGroup":
+        fuzzy_query = f"{field}:{value}~{fuzziness}"
+        self.node.add_child(fuzzy_query)
+        return self
+
+    def add_boost(self, field: str, value: str, boost: float) -> "QueryGroup":
+        boosted_query = f"{field}:{value}^{boost}"
+        self.node.add_child(boosted_query)
+        return self
+
+    def group(self, operator: LogicalOperator) -> "QueryGroup":
+        new_group = QueryGroup(self.parent, operator)
+        self.node.add_child(new_group.node)
+        return new_group
+
+    def end(self) -> ElasticsearchQueryBuilder:
+        return self.parent
+
+
+SF = TypeVar("SF", bound=SearchField)
+
+
+class ElasticDocumentQuery(Generic[SF]):
+    def __init__(self) -> None:
+        self.query_builder = ElasticsearchQueryBuilder()
+
+    @classmethod
+    def create_from(
+        cls: Type["ElasticDocumentQuery[SF]"],
+        *args: Tuple[Union[str, SF], str],
+    ) -> "ElasticDocumentQuery[SF]":
+        instance = cls()
+        for arg in args:
+            if isinstance(arg, SearchField):
+                # If the value is empty, we treat it as a wildcard search
+                logger.info(f"Adding wildcard search for field {arg}")
+                instance.add_wildcard(arg, "*")
+            elif isinstance(arg, tuple) and len(arg) == 2:
+                field, value = arg
+                assert isinstance(value, str)
+                if isinstance(field, SearchField):
+                    instance.add_field_match(field, value)
+                elif isinstance(field, str):
+                    instance.add_field_match(
+                        SearchField.from_string_field(field), value
+                    )
+                else:
+                    raise ValueError("Invalid field type {}".format(type(field)))
+        return instance
+
+    def add_field_match(
+        self, field: Union[str, SearchField], value: str, is_exact: bool = True
+    ) -> "ElasticDocumentQuery":
+        if isinstance(field, str):
+            field = SearchField.from_string_field(field)
+        self.query_builder.add_field_match(field, value, is_exact)
+        return self
+
+    def add_field_not_match(
+        self, field: SearchField, value: str, is_exact: bool = True
+    ) -> "ElasticDocumentQuery":
+        self.query_builder.add_field_not_match(field, value, is_exact)
+        return self
+
+    def add_range(
+        self,
+        field: SearchField,
+        min_value: Optional[str] = None,
+        max_value: Optional[str] = None,
+        include_min: bool = True,
+        include_max: bool = True,
+    ) -> "ElasticDocumentQuery":
+        field_name: str = field.field_name  # type: ignore
+        self.query_builder.add_range(
+            field_name, min_value, max_value, include_min, include_max
+        )
+        return self
+
+    def add_wildcard(self, field: SearchField, pattern: str) -> "ElasticDocumentQuery":
+        field_name: str = field.field_name  # type: ignore
+        self.query_builder.add_wildcard(field_name, pattern)
+        return self
+
+    def add_fuzzy(
+        self, field: SearchField, value: str, fuzziness: int = 2
+    ) -> "ElasticDocumentQuery":
+        field_name: str = field.field_name  # type: ignore
+        self.query_builder.add_fuzzy(field_name, value, fuzziness)
+        return self
+
+    def add_boost(
+        self, field: SearchField, value: str, boost: float
+    ) -> "ElasticDocumentQuery":
+        self.query_builder.add_boost(field.field_name, value, boost)
+        return self
+
+    def group(self, operator: LogicalOperator) -> QueryGroup:
+        return self.query_builder.group(operator)
+
+    def build(self) -> str:
+        return self.query_builder.build()
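[Editor's note, not part of the patch: the unit tests below cover matches and negation, so here is a sketch of the remaining builder verbs (range, fuzzy, boost). The field names are made up; the expected string follows directly from the templates in the code above.]

```python
from datahub.utilities.search_utils import (
    ElasticsearchQueryBuilder,
    LogicalOperator,
)

builder = ElasticsearchQueryBuilder(operator=LogicalOperator.AND)
builder.add_range("lastModified", min_value="2024-01-01", include_max=False)
builder.add_fuzzy("name", "platfrom", fuzziness=1)  # tolerates one edit of distance
builder.add_boost("resourceType", "dataset", boost=2.0)

# "[" includes the lower bound; "}" excludes the (open-ended) upper bound.
assert builder.build() == (
    "(lastModified:[2024-01-01 TO *} AND name:platfrom~1 AND resourceType:dataset^2.0)"
)
```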
diff --git a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py
index e6c9a9466d62b..a84e373dbe72c 100644
--- a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py
+++ b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py
@@ -4,9 +4,12 @@
 import datahub.metadata.schema_classes as models
 from datahub.api.entities.platformresource.platform_resource import (
+    ElasticPlatformResourceQuery,
     PlatformResource,
     PlatformResourceKey,
+    PlatformResourceSearchFields,
 )
+from datahub.utilities.search_utils import LogicalOperator
 
 
 def test_platform_resource_dict():
@@ -179,3 +182,15 @@ class TestModel(BaseModel):
     ).encode("utf-8")
     assert platform_resource_info_mcp.aspect.value.schemaType == "JSON"
     assert platform_resource_info_mcp.aspect.value.schemaRef == TestModel.__name__
+
+
+def test_platform_resource_filters():
+
+    query = (
+        ElasticPlatformResourceQuery.create_from()
+        .group(LogicalOperator.AND)
+        .add_field_match(PlatformResourceSearchFields.PRIMARY_KEY, "test_1")
+        .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, "server")
+        .end()
+    )
+    assert query.build() == '(primaryKey:"test_1" AND resourceType:"server")'
diff --git a/metadata-ingestion/tests/unit/utilities/test_search_utils.py b/metadata-ingestion/tests/unit/utilities/test_search_utils.py
new file mode 100644
index 0000000000000..6fa2e46c7f20e
--- /dev/null
+++ b/metadata-ingestion/tests/unit/utilities/test_search_utils.py
@@ -0,0 +1,71 @@
+from datahub.utilities.search_utils import (
+    ElasticDocumentQuery,
+    LogicalOperator,
+    SearchField,
+)
+
+
+def test_simple_and_filters():
+    query = (
+        ElasticDocumentQuery.create_from()
+        .group(LogicalOperator.AND)
+        .add_field_match("field1", "value1")
+        .add_field_match("field2", "value2")
+        .end()
+    )
+
+    assert query.build() == '(field1:"value1" AND field2:"value2")'
+
+
+def test_simple_or_filters():
+    query = (
+        ElasticDocumentQuery.create_from()
+        .group(LogicalOperator.OR)
+        .add_field_match("field1", "value1")
+        .add_field_match("field2", "value2")
+        .end()
+    )
+
+    assert query.build() == '(field1:"value1" OR field2:"value2")'
+
+    # The same query, built from SearchField objects
+    query = (
+        ElasticDocumentQuery.create_from()
+        .group(LogicalOperator.OR)
+        .add_field_match(SearchField.from_string_field("field1"), "value1")
+        .add_field_match(SearchField.from_string_field("field2"), "value2")
+        .end()
+    )
+    assert query.build() == '(field1:"value1" OR field2:"value2")'
+
+
+def test_simple_field_match():
+    query: ElasticDocumentQuery = ElasticDocumentQuery.create_from(
+        ("field1", "value1:1")
+    )
+    assert query.build() == 'field1:"value1\\:1"'
+
+    # Another way to create the same query
+    query = ElasticDocumentQuery.create_from()
+    query.add_field_match("field1", "value1:1")
+    assert query.build() == 'field1:"value1\\:1"'
+
+
+def test_negation():
+    query = (
+        ElasticDocumentQuery.create_from()
+        .group(LogicalOperator.AND)
+        .add_field_match("field1", "value1")
+        .add_field_not_match("field2", "value2")
+        .end()
+    )
+
+    assert query.build() == '(field1:"value1" AND -field2:"value2")'
+
+
+def test_multi_arg_create_from():
+    query: ElasticDocumentQuery = ElasticDocumentQuery.create_from(
+        ("field1", "value1"),
+        ("field2", "value2"),
+    )
+    assert query.build() == '(field1:"value1" AND field2:"value2")'
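[Editor's note, not part of the patch: one behavior the tests above exercise only implicitly. `is_exact=True` (the default) quotes the term, while `is_exact=False` leaves it unquoted so Elasticsearch can analyze it, per `_create_term` above.]

```python
from datahub.utilities.search_utils import ElasticDocumentQuery

exact = ElasticDocumentQuery.create_from().add_field_match("field1", "value1")
assert exact.build() == 'field1:"value1"'

analyzed = ElasticDocumentQuery.create_from().add_field_match(
    "field1", "value1", is_exact=False
)
assert analyzed.build() == "field1:value1"
```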
diff --git a/smoke-test/tests/platform_resources/test_platform_resource.py b/smoke-test/tests/platform_resources/test_platform_resource.py
index 7ebfd4d6ea15b..39d15f2e8dea6 100644
--- a/smoke-test/tests/platform_resources/test_platform_resource.py
+++ b/smoke-test/tests/platform_resources/test_platform_resource.py
@@ -5,8 +5,10 @@
 import pytest
 
 from datahub.api.entities.platformresource.platform_resource import (
+    ElasticPlatformResourceQuery,
     PlatformResource,
     PlatformResourceKey,
+    PlatformResourceSearchFields,
 )
 
 from tests.utils import wait_for_healthcheck_util, wait_for_writes_to_sync
 
@@ -42,7 +44,12 @@ def cleanup_resources(graph_client):
             logger.warning(f"Failed to delete resource: {e}")
 
     # Additional cleanup for any resources that might have been missed
-    for resource in PlatformResource.search_by_key(graph_client, "test_"):
+    for resource in PlatformResource.search_by_filters(
+        graph_client,
+        ElasticPlatformResourceQuery.create_from().add_wildcard(
+            PlatformResourceSearchFields.PRIMARY_KEY, "test_*"
+        ),
+    ):
         try:
             resource.delete(graph_client)
         except Exception as e:
@@ -114,7 +121,7 @@ def test_platform_resource_non_existent(graph_client, test_id):
     assert platform_resource is None
 
 
-def test_platform_resource_urn_secondary_key(graph_client, test_id):
+def test_platform_resource_urn_secondary_key(graph_client, test_id, cleanup_resources):
     key = PlatformResourceKey(
         platform=f"test_platform_{test_id}",
         resource_type=f"test_resource_type_{test_id}",
@@ -129,6 +136,7 @@ def test_platform_resource_urn_secondary_key(graph_client, test_id):
         secondary_keys=[dataset_urn],
     )
     platform_resource.to_datahub(graph_client)
+    cleanup_resources.append(platform_resource)
     wait_for_writes_to_sync()
 
     read_platform_resources = [
@@ -141,7 +149,9 @@ def test_platform_resource_urn_secondary_key(graph_client, test_id):
     assert read_platform_resources[0] == platform_resource
 
 
-def test_platform_resource_listing_by_resource_type(graph_client, test_id):
+def test_platform_resource_listing_by_resource_type(
+    graph_client, test_id, cleanup_resources
+):
     # Generate two resources with the same resource type
     key1 = PlatformResourceKey(
         platform=f"test_platform_{test_id}",
@@ -171,13 +181,9 @@ def test_platform_resource_listing_by_resource_type(graph_client, test_id):
         r
         for r in PlatformResource.search_by_filters(
             graph_client,
-            and_filters=[
-                {
-                    "field": "resourceType",
-                    "condition": "EQUAL",
-                    "value": key1.resource_type,
-                }
-            ],
+            query=ElasticPlatformResourceQuery.create_from(
+                (PlatformResourceSearchFields.RESOURCE_TYPE, key1.resource_type)
+            ),
         )
     ]
     assert len(search_results) == 2
@@ -186,3 +192,55 @@ def test_platform_resource_listing_by_resource_type(graph_client, test_id):
     read_platform_resource_2 = next(r for r in search_results if r.id == key2.id)
     assert read_platform_resource_1 == platform_resource1
     assert read_platform_resource_2 == platform_resource2
+
+
+def test_platform_resource_listing_complex_queries(graph_client, test_id):
+    # Generate two resources with the same resource type
+    key1 = PlatformResourceKey(
+        platform=f"test_platform1_{test_id}",
+        resource_type=f"test_resource_type_{test_id}",
+        primary_key=f"test_primary_key_1_{test_id}",
+    )
+    platform_resource1 = PlatformResource.create(
+        key=key1,
+        value={"test_key": f"test_value_1_{test_id}"},
+    )
+    platform_resource1.to_datahub(graph_client)
+
+    key2 = PlatformResourceKey(
+        platform=f"test_platform2_{test_id}",
+        resource_type=f"test_resource_type_{test_id}",
+        primary_key=f"test_primary_key_2_{test_id}",
+    )
+    platform_resource2 = PlatformResource.create(
+        key=key2,
+        value={"test_key": f"test_value_2_{test_id}"},
+    )
+    platform_resource2.to_datahub(graph_client)
+
+    wait_for_writes_to_sync()
+    from datahub.api.entities.platformresource.platform_resource import (
+        ElasticPlatformResourceQuery,
+        LogicalOperator,
+        PlatformResourceSearchFields,
+    )
+
+    query = (
+        ElasticPlatformResourceQuery.create_from()
+        .group(LogicalOperator.AND)
+        .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, key1.resource_type)
+        .add_field_not_match(PlatformResourceSearchFields.PLATFORM, key1.platform)
+        .end()
+    )
+
+    search_results = [
+        r
+        for r in PlatformResource.search_by_filters(
+            graph_client,
+            query=query,
+        )
+    ]
+    assert len(search_results) == 1
+
+    read_platform_resource = search_results[0]
+    assert read_platform_resource == platform_resource2
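[Editor's note, not part of the patch: the same machinery backs `search_by_key`, so the secondary-key smoke test above could equally be phrased as an explicit query. `graph_client` stands in for the fixture used in these tests, and the dataset URN is a placeholder.]

```python
from datahub.api.entities.platformresource.platform_resource import (
    ElasticPlatformResourceQuery,
    PlatformResource,
    PlatformResourceSearchFields,
)
from datahub.utilities.search_utils import LogicalOperator

dataset_urn = (
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.table,PROD)"  # placeholder
)

# Equivalent to PlatformResource.search_by_key(graph_client, dataset_urn, primary=False)
query = (
    ElasticPlatformResourceQuery.create_from()
    .group(LogicalOperator.OR)
    .add_field_match(PlatformResourceSearchFields.PRIMARY_KEY, dataset_urn)
    .add_field_match(PlatformResourceSearchFields.SECONDARY_KEYS, dataset_urn)
    .end()
)
results = list(PlatformResource.search_by_filters(graph_client, query=query))
```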