feat(sdk):platform-resource - complex queries (datahub-project#11675)
shirshanka authored and keith-fullsight committed Oct 21, 2024
1 parent 12abda4 commit bda79bd
Showing 5 changed files with 508 additions and 10 deletions.
69 changes: 69 additions & 0 deletions metadata-ingestion/src/datahub/utilities/openapi_utils.py
@@ -0,0 +1,69 @@
import logging
from typing import Iterable, Union

import datahub.metadata.schema_classes as models
from datahub.ingestion.graph.client import DataHubGraph
from datahub.utilities.search_utils import (
    ElasticDocumentQuery,
    ElasticsearchQueryBuilder,
)

logger = logging.getLogger(__name__)


class OpenAPIGraphClient:
    """
    An experimental client for the DataHubGraph that uses the OpenAPI endpoints
    to query entities and aspects.
    Does not support all features of the DataHubGraph.
    API is subject to change.
    DO NOT USE THIS UNLESS YOU KNOW WHAT YOU ARE DOING.
    """

    ENTITY_KEY_ASPECT_MAP = {
        aspect_type.ASPECT_INFO.get("keyForEntity"): name
        for name, aspect_type in models.ASPECT_NAME_MAP.items()
        if aspect_type.ASPECT_INFO.get("keyForEntity")
    }

    def __init__(self, graph: DataHubGraph):
        self.graph = graph
        self.openapi_base = graph._gms_server.rstrip("/") + "/openapi/v3"

    def scroll_urns_by_filter(
        self,
        entity_type: str,
        query: Union[ElasticDocumentQuery, ElasticsearchQueryBuilder],
    ) -> Iterable[str]:
        """
        Scroll through all URNs that match the given query.
        """

        key_aspect = self.ENTITY_KEY_ASPECT_MAP.get(entity_type)
        assert key_aspect, f"No key aspect found for entity type {entity_type}"

        count = 1000
        string_query = query.build()
        scroll_id = None
        logger.debug(f"Scrolling with query: {string_query}")
        while True:
            response = self.graph._get_generic(
                self.openapi_base + f"/entity/{entity_type.lower()}",
                params={
                    "systemMetadata": "false",
                    "includeSoftDelete": "false",
                    "skipCache": "false",
                    "aspects": [key_aspect],
                    "scrollId": scroll_id,
                    "count": count,
                    "query": string_query,
                },
            )
            entities = response.get("entities", [])
            scroll_id = response.get("scrollId")
            for entity in entities:
                yield entity["urn"]
            if not scroll_id:
                break
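A minimal sketch of how this client might be driven, assuming an already-reachable GMS instance and that DatahubClientConfig can be imported alongside DataHubGraph; the server address, the platformResource entity type, and the platform field below are illustrative assumptions, not part of this commit:

from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

from datahub.utilities.openapi_utils import OpenAPIGraphClient
from datahub.utilities.search_utils import ElasticDocumentQuery, SearchField

# Assumed connection details; point this at a real GMS endpoint.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
client = OpenAPIGraphClient(graph)

# Build a structured query; "platform" and its value are placeholder filters.
query = ElasticDocumentQuery.create_from(
    (SearchField.from_string_field("platform"), "slack"),
)

# scroll_urns_by_filter is a generator: pages of up to 1000 URNs are fetched
# from the OpenAPI v3 endpoint only as the loop consumes them.
for urn in client.scroll_urns_by_filter(entity_type="platformResource", query=query):
    print(urn)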
285 changes: 285 additions & 0 deletions metadata-ingestion/src/datahub/utilities/search_utils.py
@@ -0,0 +1,285 @@
import logging
import re
from enum import Enum
from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union

logger = logging.getLogger(__name__)


class LogicalOperator(Enum):
    AND = "AND"
    OR = "OR"


class SearchField:
    def __init__(self, field_name: str):
        self.field_name = field_name

    def get_search_value(self, value: str) -> str:
        return value

    def __str__(self) -> str:
        return self.field_name

    def __repr__(self) -> str:
        return self.__str__()

    @classmethod
    def from_string_field(cls, field_name: str) -> "SearchField":
        return cls(field_name)


class QueryNode:
    def __init__(self, operator: Optional[LogicalOperator] = None):
        self.operator = operator
        self.children: List[Union[QueryNode, str]] = []

    def add_child(self, child: Union["QueryNode", str]) -> None:
        self.children.append(child)

    def build(self) -> str:
        if not self.children:
            return ""

        if self.operator is None:
            return (
                self.children[0]
                if isinstance(self.children[0], str)
                else self.children[0].build()
            )

        child_queries = []
        for child in self.children:
            if isinstance(child, str):
                child_queries.append(child)
            else:
                child_queries.append(child.build())

        joined_queries = f" {self.operator.value} ".join(child_queries)
        return f"({joined_queries})" if len(child_queries) > 1 else joined_queries


class ElasticsearchQueryBuilder:
    SPECIAL_CHARACTERS = r'+-=&|><!(){}[]^"~*?:\/'

    def __init__(self, operator: LogicalOperator = LogicalOperator.AND) -> None:
        self.root = QueryNode(operator=operator)

    @classmethod
    def escape_special_characters(cls, value: str) -> str:
        """
        Escape special characters in the search term.
        """
        return re.sub(f"([{re.escape(cls.SPECIAL_CHARACTERS)}])", r"\\\1", value)

    def _create_term(
        self, field: SearchField, value: str, is_exact: bool = False
    ) -> str:
        escaped_value = self.escape_special_characters(field.get_search_value(value))
        field_name: str = field.field_name
        if is_exact:
            return f'{field_name}:"{escaped_value}"'
        return f"{field_name}:{escaped_value}"

    def add_field_match(
        self, field: SearchField, value: str, is_exact: bool = True
    ) -> "ElasticsearchQueryBuilder":
        term = self._create_term(field, value, is_exact)
        self.root.add_child(term)
        return self

    def add_field_not_match(
        self, field: SearchField, value: str, is_exact: bool = True
    ) -> "ElasticsearchQueryBuilder":
        term = f"-{self._create_term(field, value, is_exact)}"
        self.root.add_child(term)
        return self

    def add_range(
        self,
        field: str,
        min_value: Optional[str] = None,
        max_value: Optional[str] = None,
        include_min: bool = True,
        include_max: bool = True,
    ) -> "ElasticsearchQueryBuilder":
        min_bracket = "[" if include_min else "{"
        max_bracket = "]" if include_max else "}"
        min_val = min_value if min_value is not None else "*"
        max_val = max_value if max_value is not None else "*"
        range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}"
        self.root.add_child(range_query)
        return self

    def add_wildcard(self, field: str, pattern: str) -> "ElasticsearchQueryBuilder":
        wildcard_query = f"{field}:{pattern}"
        self.root.add_child(wildcard_query)
        return self

    def add_fuzzy(
        self, field: str, value: str, fuzziness: int = 2
    ) -> "ElasticsearchQueryBuilder":
        fuzzy_query = f"{field}:{value}~{fuzziness}"
        self.root.add_child(fuzzy_query)
        return self

    def add_boost(
        self, field: str, value: str, boost: float
    ) -> "ElasticsearchQueryBuilder":
        boosted_query = f"{field}:{value}^{boost}"
        self.root.add_child(boosted_query)
        return self

    def group(self, operator: LogicalOperator) -> "QueryGroup":
        return QueryGroup(self, operator)

    def build(self) -> str:
        return self.root.build()


class QueryGroup:
    def __init__(self, parent: ElasticsearchQueryBuilder, operator: LogicalOperator):
        self.parent = parent
        self.node = QueryNode(operator)
        self.parent.root.add_child(self.node)

    def add_field_match(
        self, field: Union[str, SearchField], value: str, is_exact: bool = True
    ) -> "QueryGroup":
        if isinstance(field, str):
            field = SearchField.from_string_field(field)
        term = self.parent._create_term(field, value, is_exact)
        self.node.add_child(term)
        return self

    def add_field_not_match(
        self, field: Union[str, SearchField], value: str, is_exact: bool = True
    ) -> "QueryGroup":
        if isinstance(field, str):
            field = SearchField.from_string_field(field)
        term = f"-{self.parent._create_term(field, value, is_exact)}"
        self.node.add_child(term)
        return self

    def add_range(
        self,
        field: str,
        min_value: Optional[str] = None,
        max_value: Optional[str] = None,
        include_min: bool = True,
        include_max: bool = True,
    ) -> "QueryGroup":
        min_bracket = "[" if include_min else "{"
        max_bracket = "]" if include_max else "}"
        min_val = min_value if min_value is not None else "*"
        max_val = max_value if max_value is not None else "*"
        range_query = f"{field}:{min_bracket}{min_val} TO {max_val}{max_bracket}"
        self.node.add_child(range_query)
        return self

    def add_wildcard(self, field: str, pattern: str) -> "QueryGroup":
        wildcard_query = f"{field}:{pattern}"
        self.node.add_child(wildcard_query)
        return self

    def add_fuzzy(self, field: str, value: str, fuzziness: int = 2) -> "QueryGroup":
        fuzzy_query = f"{field}:{value}~{fuzziness}"
        self.node.add_child(fuzzy_query)
        return self

    def add_boost(self, field: str, value: str, boost: float) -> "QueryGroup":
        boosted_query = f"{field}:{value}^{boost}"
        self.node.add_child(boosted_query)
        return self

    def group(self, operator: LogicalOperator) -> "QueryGroup":
        new_group = QueryGroup(self.parent, operator)
        self.node.add_child(new_group.node)
        return new_group

    def end(self) -> ElasticsearchQueryBuilder:
        return self.parent


SF = TypeVar("SF", bound=SearchField)


class ElasticDocumentQuery(Generic[SF]):
    def __init__(self) -> None:
        self.query_builder = ElasticsearchQueryBuilder()

    @classmethod
    def create_from(
        cls: Type["ElasticDocumentQuery[SF]"],
        *args: Tuple[Union[str, SF], str],
    ) -> "ElasticDocumentQuery[SF]":
        instance = cls()
        for arg in args:
            if isinstance(arg, SearchField):
                # A bare SearchField (no value) is treated as a wildcard search
                logger.info(f"Adding wildcard search for field {arg}")
                instance.add_wildcard(arg, "*")
            elif isinstance(arg, tuple) and len(arg) == 2:
                field, value = arg
                assert isinstance(value, str)
                if isinstance(field, SearchField):
                    instance.add_field_match(field, value)
                elif isinstance(field, str):
                    instance.add_field_match(
                        SearchField.from_string_field(field), value
                    )
                else:
                    raise ValueError("Invalid field type {}".format(type(field)))
        return instance

    def add_field_match(
        self, field: Union[str, SearchField], value: str, is_exact: bool = True
    ) -> "ElasticDocumentQuery":
        if isinstance(field, str):
            field = SearchField.from_string_field(field)
        self.query_builder.add_field_match(field, value, is_exact)
        return self

    def add_field_not_match(
        self, field: SearchField, value: str, is_exact: bool = True
    ) -> "ElasticDocumentQuery":
        self.query_builder.add_field_not_match(field, value, is_exact)
        return self

    def add_range(
        self,
        field: SearchField,
        min_value: Optional[str] = None,
        max_value: Optional[str] = None,
        include_min: bool = True,
        include_max: bool = True,
    ) -> "ElasticDocumentQuery":
        field_name: str = field.field_name  # type: ignore
        self.query_builder.add_range(
            field_name, min_value, max_value, include_min, include_max
        )
        return self

    def add_wildcard(self, field: SearchField, pattern: str) -> "ElasticDocumentQuery":
        field_name: str = field.field_name  # type: ignore
        self.query_builder.add_wildcard(field_name, pattern)
        return self

    def add_fuzzy(
        self, field: SearchField, value: str, fuzziness: int = 2
    ) -> "ElasticDocumentQuery":
        field_name: str = field.field_name  # type: ignore
        self.query_builder.add_fuzzy(field_name, value, fuzziness)
        return self

    def add_boost(
        self, field: SearchField, value: str, boost: float
    ) -> "ElasticDocumentQuery":
        self.query_builder.add_boost(field.field_name, value, boost)
        return self

    def group(self, operator: LogicalOperator) -> QueryGroup:
        return self.query_builder.group(operator)

    def build(self) -> str:
        return self.query_builder.build()
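To round out the picture, here is a hedged sketch of builder features that the test in the next file does not exercise (wildcards, ranges, and nested groups); the field names name, lastModified, and platform are placeholders rather than fields defined by this commit:

from datahub.utilities.search_utils import (
    ElasticsearchQueryBuilder,
    LogicalOperator,
    SearchField,
)

# Top-level terms are combined with the builder's root operator (AND here).
builder = ElasticsearchQueryBuilder(operator=LogicalOperator.AND)
builder.add_wildcard("name", "customer_*")
builder.add_range("lastModified", min_value="2024-01-01", include_min=True)

# A nested group ORs its own terms and is ANDed with the terms above.
builder.group(LogicalOperator.OR).add_field_match(
    SearchField.from_string_field("platform"), "snowflake"
).add_field_match("platform", "bigquery").end()

print(builder.build())
# (name:customer_* AND lastModified:[2024-01-01 TO *] AND (platform:"snowflake" OR platform:"bigquery"))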
@@ -4,9 +4,12 @@

import datahub.metadata.schema_classes as models
from datahub.api.entities.platformresource.platform_resource import (
    ElasticPlatformResourceQuery,
    PlatformResource,
    PlatformResourceKey,
    PlatformResourceSearchFields,
)
from datahub.utilities.search_utils import LogicalOperator


def test_platform_resource_dict():
@@ -179,3 +182,15 @@ class TestModel(BaseModel):
    ).encode("utf-8")
    assert platform_resource_info_mcp.aspect.value.schemaType == "JSON"
    assert platform_resource_info_mcp.aspect.value.schemaRef == TestModel.__name__


def test_platform_resource_filters():
    query = (
        ElasticPlatformResourceQuery.create_from()
        .group(LogicalOperator.AND)
        .add_field_match(PlatformResourceSearchFields.PRIMARY_KEY, "test_1")
        .add_field_match(PlatformResourceSearchFields.RESOURCE_TYPE, "server")
        .end()
    )
    assert query.build() == '(primaryKey:"test_1" AND resourceType:"server")'
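Assuming the members of PlatformResourceSearchFields are SearchField instances (which the grouped assertion above suggests), the same filter could presumably also be expressed with the tuple shorthand on create_from; a sketch, not part of this commit:

from datahub.api.entities.platformresource.platform_resource import (
    ElasticPlatformResourceQuery,
    PlatformResourceSearchFields,
)

query = ElasticPlatformResourceQuery.create_from(
    (PlatformResourceSearchFields.PRIMARY_KEY, "test_1"),
    (PlatformResourceSearchFields.RESOURCE_TYPE, "server"),
)

# Top-level terms default to AND, so this should build the same string as the
# grouped query in the test above:
print(query.build())  # '(primaryKey:"test_1" AND resourceType:"server")'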