From bda79bd489753ec017a503348aa57af617db99a9 Mon Sep 17 00:00:00 2001 From: Shirshanka Das Date: Sat, 19 Oct 2024 14:53:28 -0700 Subject: [PATCH] feat(sdk):platform-resource - complex queries (#11675) --- .../src/datahub/utilities/openapi_utils.py | 69 +++++ .../src/datahub/utilities/search_utils.py | 285 ++++++++++++++++++ .../test_platform_resource.py | 15 + .../tests/unit/utilities/test_search_utils.py | 71 +++++ .../test_platform_resource.py | 78 ++++- 5 files changed, 508 insertions(+), 10 deletions(-) create mode 100644 metadata-ingestion/src/datahub/utilities/openapi_utils.py create mode 100644 metadata-ingestion/src/datahub/utilities/search_utils.py create mode 100644 metadata-ingestion/tests/unit/utilities/test_search_utils.py diff --git a/metadata-ingestion/src/datahub/utilities/openapi_utils.py b/metadata-ingestion/src/datahub/utilities/openapi_utils.py new file mode 100644 index 0000000000000..e704ff7f84cbb --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/openapi_utils.py @@ -0,0 +1,69 @@ +import logging +from typing import Iterable, Union + +import datahub.metadata.schema_classes as models +from datahub.ingestion.graph.client import DataHubGraph +from datahub.utilities.search_utils import ( + ElasticDocumentQuery, + ElasticsearchQueryBuilder, +) + +logger = logging.getLogger(__name__) + + +class OpenAPIGraphClient: + """ + An experimental client for the DataHubGraph that uses the OpenAPI endpoints + to query entities and aspects. + Does not support all features of the DataHubGraph. + API is subject to change. + + DO NOT USE THIS UNLESS YOU KNOW WHAT YOU ARE DOING. + """ + + ENTITY_KEY_ASPECT_MAP = { + aspect_type.ASPECT_INFO.get("keyForEntity"): name + for name, aspect_type in models.ASPECT_NAME_MAP.items() + if aspect_type.ASPECT_INFO.get("keyForEntity") + } + + def __init__(self, graph: DataHubGraph): + self.graph = graph + self.openapi_base = graph._gms_server.rstrip("/") + "/openapi/v3" + + def scroll_urns_by_filter( + self, + entity_type: str, + query: Union[ElasticDocumentQuery, ElasticsearchQueryBuilder], + ) -> Iterable[str]: + """ + Scroll through all urns that match the given filters. + + """ + + key_aspect = self.ENTITY_KEY_ASPECT_MAP.get(entity_type) + assert key_aspect, f"No key aspect found for entity type {entity_type}" + + count = 1000 + string_query = query.build() + scroll_id = None + logger.debug(f"Scrolling with query: {string_query}") + while True: + response = self.graph._get_generic( + self.openapi_base + f"/entity/{entity_type.lower()}", + params={ + "systemMetadata": "false", + "includeSoftDelete": "false", + "skipCache": "false", + "aspects": [key_aspect], + "scrollId": scroll_id, + "count": count, + "query": string_query, + }, + ) + entities = response.get("entities", []) + scroll_id = response.get("scrollId") + for entity in entities: + yield entity["urn"] + if not scroll_id: + break diff --git a/metadata-ingestion/src/datahub/utilities/search_utils.py b/metadata-ingestion/src/datahub/utilities/search_utils.py new file mode 100644 index 0000000000000..0bd88addd8660 --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/search_utils.py @@ -0,0 +1,285 @@ +import logging +import re +from enum import Enum +from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union + +logger = logging.getLogger(__name__) + + +class LogicalOperator(Enum): + AND = "AND" + OR = "OR" + + +class SearchField: + def __init__(self, field_name: str): + self.field_name = field_name + + def get_search_value(self, value: str) -> str: + return value + + def __str__(self) -> str: + return self.field_name + + def __repr__(self) -> str: + return self.__str__() + + @classmethod + def from_string_field(cls, field_name: str) -> "SearchField": + return cls(field_name) + + +class QueryNode: + def __init__(self, operator: Optional[LogicalOperator] = None): + self.operator = operator + self.children: List[Union[QueryNode, str]] = [] + + def add_child(self, child: Union["QueryNode", str]) -> None: + self.children.append(child) + + def build(self) -> str: + if not self.children: + return "" + + if self.operator is None: + return ( + self.children[0] + if isinstance(self.children[0], str) + else self.children[0].build() + ) + + child_queries = [] + for child in self.children: + if isinstance(child, str): + child_queries.append(child) + else: + child_queries.append(child.build()) + + joined_queries = f" {self.operator.value} ".join(child_queries) + return f"({joined_queries})" if len(child_queries) > 1 else joined_queries + + +class ElasticsearchQueryBuilder: + SPECIAL_CHARACTERS = r'+-=&|> None: + self.root = QueryNode(operator=operator) + + @classmethod + def escape_special_characters(cls, value: str) -> str: + """ + Escape special characters in the search term. + """ + return re.sub(f"([{re.escape(cls.SPECIAL_CHARACTERS)}])", r"\\\1", value) + + def _create_term( + self, field: SearchField, value: str, is_exact: bool = False + ) -> str: + escaped_value = self.escape_special_characters(field.get_search_value(value)) + field_name: str = field.field_name + if is_exact: + return f'{field_name}:"{escaped_value}"' + return f"{field_name}:{escaped_value}" + + def add_field_match( + self, field: SearchField, value: str, is_exact: bool = True + ) -> "ElasticsearchQueryBuilder": + term = self._create_term(field, value, is_exact) + self.root.add_child(term) + return self + + def add_field_not_match( + self, field: SearchField, value: str, is_exact: bool = True + ) -> "ElasticsearchQueryBuilder": + term = f"-{self._create_term(field, value, is_exact)}" + self.root.add_child(term) + return self + + def add_range( + self, + field: str, + min_value: Optional[str] = None, + max_value: Optional[str] = None, + include_min: bool = True, + include_max: bool = True, + ) -> "ElasticsearchQueryBuilder": + min_bracket = "[" if include_min else "{" + max_bracket = "]" if include_max else "}" + min_val = min_value if min_value is not None else "*" + max_val = max_value if max_value is not None else "*" + range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}" + self.root.add_child(range_query) + return self + + def add_wildcard(self, field: str, pattern: str) -> "ElasticsearchQueryBuilder": + wildcard_query = f"{field}:{pattern}" + self.root.add_child(wildcard_query) + return self + + def add_fuzzy( + self, field: str, value: str, fuzziness: int = 2 + ) -> "ElasticsearchQueryBuilder": + fuzzy_query = f"{field}:{value}~{fuzziness}" + self.root.add_child(fuzzy_query) + return self + + def add_boost( + self, field: str, value: str, boost: float + ) -> "ElasticsearchQueryBuilder": + boosted_query = f"{field}:{value}^{boost}" + self.root.add_child(boosted_query) + return self + + def group(self, operator: LogicalOperator) -> "QueryGroup": + return QueryGroup(self, operator) + + def build(self) -> str: + return self.root.build() + + +class QueryGroup: + def __init__(self, parent: ElasticsearchQueryBuilder, operator: LogicalOperator): + self.parent = parent + self.node = QueryNode(operator) + self.parent.root.add_child(self.node) + + def add_field_match( + self, field: Union[str, SearchField], value: str, is_exact: bool = True + ) -> "QueryGroup": + if isinstance(field, str): + field = SearchField.from_string_field(field) + term = self.parent._create_term(field, value, is_exact) + self.node.add_child(term) + return self + + def add_field_not_match( + self, field: Union[str, SearchField], value: str, is_exact: bool = True + ) -> "QueryGroup": + if isinstance(field, str): + field = SearchField.from_string_field(field) + term = f"-{self.parent._create_term(field, value, is_exact)}" + self.node.add_child(term) + return self + + def add_range( + self, + field: str, + min_value: Optional[str] = None, + max_value: Optional[str] = None, + include_min: bool = True, + include_max: bool = True, + ) -> "QueryGroup": + min_bracket = "[" if include_min else "{" + max_bracket = "]" if include_max else "}" + min_val = min_value if min_value is not None else "*" + max_val = max_value if max_value is not None else "*" + range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}" + self.node.add_child(range_query) + return self + + def add_wildcard(self, field: str, pattern: str) -> "QueryGroup": + wildcard_query = f"{field}:{pattern}" + self.node.add_child(wildcard_query) + return self + + def add_fuzzy(self, field: str, value: str, fuzziness: int = 2) -> "QueryGroup": + fuzzy_query = f"{field}:{value}~{fuzziness}" + self.node.add_child(fuzzy_query) + return self + + def add_boost(self, field: str, value: str, boost: float) -> "QueryGroup": + boosted_query = f"{field}:{value}^{boost}" + self.node.add_child(boosted_query) + return self + + def group(self, operator: LogicalOperator) -> "QueryGroup": + new_group = QueryGroup(self.parent, operator) + self.node.add_child(new_group.node) + return new_group + + def end(self) -> ElasticsearchQueryBuilder: + return self.parent + + +SF = TypeVar("SF", bound=SearchField) + + +class ElasticDocumentQuery(Generic[SF]): + def __init__(self) -> None: + self.query_builder = ElasticsearchQueryBuilder() + + @classmethod + def create_from( + cls: Type["ElasticDocumentQuery[SF]"], + *args: Tuple[Union[str, SF], str], + ) -> "ElasticDocumentQuery[SF]": + instance = cls() + for arg in args: + if isinstance(arg, SearchField): + # If the value is empty, we treat it as a wildcard search + logger.info(f"Adding wildcard search for field {arg}") + instance.add_wildcard(arg, "*") + elif isinstance(arg, tuple) and len(arg) == 2: + field, value = arg + assert isinstance(value, str) + if isinstance(field, SearchField): + instance.add_field_match(field, value) + elif isinstance(field, str): + instance.add_field_match( + SearchField.from_string_field(field), value + ) + else: + raise ValueError("Invalid field type {}".format(type(field))) + return instance + + def add_field_match( + self, field: Union[str, SearchField], value: str, is_exact: bool = True + ) -> "ElasticDocumentQuery": + if isinstance(field, str): + field = SearchField.from_string_field(field) + self.query_builder.add_field_match(field, value, is_exact) + return self + + def add_field_not_match( + self, field: SearchField, value: str, is_exact: bool = True + ) -> "ElasticDocumentQuery": + self.query_builder.add_field_not_match(field, value, is_exact) + return self + + def add_range( + self, + field: SearchField, + min_value: Optional[str] = None, + max_value: Optional[str] = None, + include_min: bool = True, + include_max: bool = True, + ) -> "ElasticDocumentQuery": + field_name: str = field.field_name # type: ignore + self.query_builder.add_range( + field_name, min_value, max_value, include_min, include_max + ) + return self + + def add_wildcard(self, field: SearchField, pattern: str) -> "ElasticDocumentQuery": + field_name: str = field.field_name # type: ignore + self.query_builder.add_wildcard(field_name, pattern) + return self + + def add_fuzzy( + self, field: SearchField, value: str, fuzziness: int = 2 + ) -> "ElasticDocumentQuery": + field_name: str = field.field_name # type: ignore + self.query_builder.add_fuzzy(field_name, value, fuzziness) + return self + + def add_boost( + self, field: SearchField, value: str, boost: float + ) -> "ElasticDocumentQuery": + self.query_builder.add_boost(field.field_name, value, boost) + return self + + def group(self, operator: LogicalOperator) -> QueryGroup: + return self.query_builder.group(operator) + + def build(self) -> str: + return self.query_builder.build() diff --git a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py index e6c9a9466d62b..a84e373dbe72c 100644 --- a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py +++ b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py @@ -4,9 +4,12 @@ import datahub.metadata.schema_classes as models from datahub.api.entities.platformresource.platform_resource import ( + ElasticPlatformResourceQuery, PlatformResource, PlatformResourceKey, + PlatformResourceSearchFields, ) +from datahub.utilities.search_utils import LogicalOperator def test_platform_resource_dict(): @@ -179,3 +182,15 @@ class TestModel(BaseModel): ).encode("utf-8") assert platform_resource_info_mcp.aspect.value.schemaType == "JSON" assert platform_resource_info_mcp.aspect.value.schemaRef == TestModel.__name__ + + +def test_platform_resource_filters(): + + query = ( + ElasticPlatformResourceQuery.create_from() + .group(LogicalOperator.AND) + .add_field_match(PlatformResourceSearchFields.PRIMARY_KEY, "test_1") + .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, "server") + .end() + ) + assert query.build() == '(primaryKey:"test_1" AND resourceType:"server")' diff --git a/metadata-ingestion/tests/unit/utilities/test_search_utils.py b/metadata-ingestion/tests/unit/utilities/test_search_utils.py new file mode 100644 index 0000000000000..6fa2e46c7f20e --- /dev/null +++ b/metadata-ingestion/tests/unit/utilities/test_search_utils.py @@ -0,0 +1,71 @@ +from datahub.utilities.search_utils import ( + ElasticDocumentQuery, + LogicalOperator, + SearchField, +) + + +def test_simple_and_filters(): + query = ( + ElasticDocumentQuery.create_from() + .group(LogicalOperator.AND) + .add_field_match("field1", "value1") + .add_field_match("field2", "value2") + .end() + ) + + assert query.build() == '(field1:"value1" AND field2:"value2")' + + +def test_simple_or_filters(): + query = ( + ElasticDocumentQuery.create_from() + .group(LogicalOperator.OR) + .add_field_match("field1", "value1") + .add_field_match("field2", "value2") + .end() + ) + + assert query.build() == '(field1:"value1" OR field2:"value2")' + + # Use SearchFilter to create this query + query = ( + ElasticDocumentQuery.create_from() + .group(LogicalOperator.OR) + .add_field_match(SearchField.from_string_field("field1"), "value1") + .add_field_match(SearchField.from_string_field("field2"), "value2") + .end() + ) + assert query.build() == '(field1:"value1" OR field2:"value2")' + + +def test_simple_field_match(): + query: ElasticDocumentQuery = ElasticDocumentQuery.create_from( + ("field1", "value1:1") + ) + assert query.build() == 'field1:"value1\\:1"' + + # Another way to create the same query + query = ElasticDocumentQuery.create_from() + query.add_field_match("field1", "value1:1") + assert query.build() == 'field1:"value1\\:1"' + + +def test_negation(): + query = ( + ElasticDocumentQuery.create_from() + .group(LogicalOperator.AND) + .add_field_match("field1", "value1") + .add_field_not_match("field2", "value2") + .end() + ) + + assert query.build() == '(field1:"value1" AND -field2:"value2")' + + +def test_multi_arg_create_from(): + query: ElasticDocumentQuery = ElasticDocumentQuery.create_from( + ("field1", "value1"), + ("field2", "value2"), + ) + assert query.build() == '(field1:"value1" AND field2:"value2")' diff --git a/smoke-test/tests/platform_resources/test_platform_resource.py b/smoke-test/tests/platform_resources/test_platform_resource.py index 7ebfd4d6ea15b..39d15f2e8dea6 100644 --- a/smoke-test/tests/platform_resources/test_platform_resource.py +++ b/smoke-test/tests/platform_resources/test_platform_resource.py @@ -5,8 +5,10 @@ import pytest from datahub.api.entities.platformresource.platform_resource import ( + ElasticPlatformResourceQuery, PlatformResource, PlatformResourceKey, + PlatformResourceSearchFields, ) from tests.utils import wait_for_healthcheck_util, wait_for_writes_to_sync @@ -42,7 +44,12 @@ def cleanup_resources(graph_client): logger.warning(f"Failed to delete resource: {e}") # Additional cleanup for any resources that might have been missed - for resource in PlatformResource.search_by_key(graph_client, "test_"): + for resource in PlatformResource.search_by_filters( + graph_client, + ElasticPlatformResourceQuery.create_from().add_wildcard( + PlatformResourceSearchFields.PRIMARY_KEY, "test_*" + ), + ): try: resource.delete(graph_client) except Exception as e: @@ -114,7 +121,7 @@ def test_platform_resource_non_existent(graph_client, test_id): assert platform_resource is None -def test_platform_resource_urn_secondary_key(graph_client, test_id): +def test_platform_resource_urn_secondary_key(graph_client, test_id, cleanup_resources): key = PlatformResourceKey( platform=f"test_platform_{test_id}", resource_type=f"test_resource_type_{test_id}", @@ -129,6 +136,7 @@ def test_platform_resource_urn_secondary_key(graph_client, test_id): secondary_keys=[dataset_urn], ) platform_resource.to_datahub(graph_client) + cleanup_resources.append(platform_resource) wait_for_writes_to_sync() read_platform_resources = [ @@ -141,7 +149,9 @@ def test_platform_resource_urn_secondary_key(graph_client, test_id): assert read_platform_resources[0] == platform_resource -def test_platform_resource_listing_by_resource_type(graph_client, test_id): +def test_platform_resource_listing_by_resource_type( + graph_client, test_id, cleanup_resources +): # Generate two resources with the same resource type key1 = PlatformResourceKey( platform=f"test_platform_{test_id}", @@ -171,13 +181,9 @@ def test_platform_resource_listing_by_resource_type(graph_client, test_id): r for r in PlatformResource.search_by_filters( graph_client, - and_filters=[ - { - "field": "resourceType", - "condition": "EQUAL", - "value": key1.resource_type, - } - ], + query=ElasticPlatformResourceQuery.create_from( + (PlatformResourceSearchFields.RESOURCE_TYPE, key1.resource_type) + ), ) ] assert len(search_results) == 2 @@ -186,3 +192,55 @@ def test_platform_resource_listing_by_resource_type(graph_client, test_id): read_platform_resource_2 = next(r for r in search_results if r.id == key2.id) assert read_platform_resource_1 == platform_resource1 assert read_platform_resource_2 == platform_resource2 + + +def test_platform_resource_listing_complex_queries(graph_client, test_id): + # Generate two resources with the same resource type + key1 = PlatformResourceKey( + platform=f"test_platform1_{test_id}", + resource_type=f"test_resource_type_{test_id}", + primary_key=f"test_primary_key_1_{test_id}", + ) + platform_resource1 = PlatformResource.create( + key=key1, + value={"test_key": f"test_value_1_{test_id}"}, + ) + platform_resource1.to_datahub(graph_client) + + key2 = PlatformResourceKey( + platform=f"test_platform2_{test_id}", + resource_type=f"test_resource_type_{test_id}", + primary_key=f"test_primary_key_2_{test_id}", + ) + platform_resource2 = PlatformResource.create( + key=key2, + value={"test_key": f"test_value_2_{test_id}"}, + ) + platform_resource2.to_datahub(graph_client) + + wait_for_writes_to_sync() + from datahub.api.entities.platformresource.platform_resource import ( + ElasticPlatformResourceQuery, + LogicalOperator, + PlatformResourceSearchFields, + ) + + query = ( + ElasticPlatformResourceQuery.create_from() + .group(LogicalOperator.AND) + .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, key1.resource_type) + .add_field_not_match(PlatformResourceSearchFields.PLATFORM, key1.platform) + .end() + ) + + search_results = [ + r + for r in PlatformResource.search_by_filters( + graph_client, + query=query, + ) + ] + assert len(search_results) == 1 + + read_platform_resource = search_results[0] + assert read_platform_resource == platform_resource2