feat(ingest/datahub): Add way to filter soft deleted entities (#11738)
treff7es authored Oct 30, 2024
1 parent c870450 commit b33ad0a
Showing 7 changed files with 190 additions and 92 deletions.
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Optional, Set

from pydantic import Field, root_validator

@@ -35,6 +35,19 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
),
)

include_soft_deleted_entities: bool = Field(
default=True,
description=(
"If enabled, include entities that have been soft deleted. "
"Otherwise, include all entities regardless of removal status. "
),
)

exclude_aspects: Set[str] = Field(
default_factory=set,
description="Set of aspect names to exclude from ingestion",
)

database_query_batch_size: int = Field(
default=DEFAULT_DATABASE_BATCH_SIZE,
description="Number of records to fetch from the database at a time",
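Not part of the commit itself, but for context: a minimal sketch of how the two new options above might be set on the source config. It assumes the config module lives at datahub.ingestion.source.datahub.config, that the remaining fields keep their defaults, and the aspect names passed to exclude_aspects are purely illustrative.

from datahub.ingestion.source.datahub.config import DataHubSourceConfig  # assumed module path

# Hypothetical settings: skip soft-deleted entities and drop two illustrative aspects.
config = DataHubSourceConfig.parse_obj(
    {
        "include_soft_deleted_entities": False,
        "exclude_aspects": ["datahubIngestionRunSummary", "dataHubExecutionRequestResult"],
    }
)

print(config.include_soft_deleted_entities)  # False
print(config.exclude_aspects)  # pydantic coerces the list into a Set[str]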
@@ -26,11 +26,17 @@ def __init__(
self.report = report
self.graph = graph

def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
def get_urns(self) -> Iterable[str]:
urns = self.graph.get_urns_by_filter(
status=RemovedStatusFilter.ALL,
status=RemovedStatusFilter.ALL
if self.config.include_soft_deleted_entities
else RemovedStatusFilter.NOT_SOFT_DELETED,
batch_size=self.config.database_query_batch_size,
)
return urns

def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
urns = self.get_urns()
tasks: List[futures.Future[Iterable[MetadataChangeProposalWrapper]]] = []
with futures.ThreadPoolExecutor(
max_workers=self.config.max_workers
@@ -43,6 +49,9 @@ def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
def _get_aspects_for_urn(self, urn: str) -> Iterable[MetadataChangeProposalWrapper]:
aspects: Dict[str, _Aspect] = self.graph.get_entity_semityped(urn) # type: ignore
for aspect in aspects.values():
if aspect.get_aspect_name().lower() in self.config.exclude_aspects:
continue

yield MetadataChangeProposalWrapper(
entityUrn=urn,
aspect=aspect,
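One detail worth calling out in the API-reader hunk above: the aspect name is lowercased before the membership test, so on this code path entries in exclude_aspects only match if they are given in lowercase. A tiny standalone sketch of the same check, using a made-up aspect name:

# Illustrative only: mirrors the membership test in _get_aspects_for_urn above.
exclude_aspects = {"datasetprofile"}  # hypothetical entry, already lowercase

def should_skip(aspect_name: str) -> bool:
    return aspect_name.lower() in exclude_aspects

print(should_skip("datasetProfile"))  # True  -> aspect would be skipped
print(should_skip("status"))          # False -> aspect would be emitted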
@@ -1,11 +1,10 @@
import contextlib
import json
import logging
from datetime import datetime
from typing import Any, Generic, Iterable, List, Optional, Tuple, TypeVar
from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine import Row
from typing_extensions import Protocol

from datahub.emitter.aspect import ASPECT_MAP
from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -21,13 +20,7 @@
# Should work for at least mysql, mariadb, postgres
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"


class VersionOrderable(Protocol):
createdon: Any # Should restrict to only orderable types
version: int


ROW = TypeVar("ROW", bound=VersionOrderable)
ROW = TypeVar("ROW", bound=Dict[str, Any])


class VersionOrderer(Generic[ROW]):
@@ -54,22 +47,22 @@ def _process_row(self, row: ROW) -> Iterable[ROW]:
return

yield from self._attempt_queue_flush(row)
if row.version == 0:
if row["version"] == 0:
self._add_to_queue(row)
else:
yield row

def _add_to_queue(self, row: ROW) -> None:
if self.queue is None:
self.queue = (row.createdon, [row])
self.queue = (row["createdon"], [row])
else:
self.queue[1].append(row)

def _attempt_queue_flush(self, row: ROW) -> Iterable[ROW]:
if self.queue is None:
return

if row.createdon > self.queue[0]:
if row["createdon"] > self.queue[0]:
yield from self._flush_queue()

def _flush_queue(self) -> Iterable[ROW]:
@@ -92,6 +85,21 @@ def __init__(
**connection_config.options,
)

@property
def soft_deleted_urns_query(self) -> str:
return f"""
SELECT DISTINCT mav.urn
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
JOIN (
SELECT *,
JSON_EXTRACT(metadata, '$.removed') as removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE aspect = "status" AND version = 0
) as sd ON sd.urn = mav.urn
WHERE sd.removed = true
ORDER BY mav.urn
"""

@property
def query(self) -> str:
# May repeat rows for the same date
@@ -101,66 +109,117 @@ def query(self) -> str:
# Relies on createdon order to reflect version order
# Ordering of entries with the same createdon is handled by VersionOrderer
return f"""
SELECT urn, aspect, metadata, systemmetadata, createdon, version
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE createdon >= %(since_createdon)s
{"" if self.config.include_all_versions else "AND version = 0"}
ORDER BY createdon, urn, aspect, version
LIMIT %(limit)s
OFFSET %(offset)s
SELECT *
FROM (
SELECT
mav.urn,
mav.aspect,
mav.metadata,
mav.systemmetadata,
mav.createdon,
mav.version,
removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
LEFT JOIN (
SELECT
*,
JSON_EXTRACT(metadata, '$.removed') as removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE aspect = 'status'
AND version = 0
) as sd ON sd.urn = mav.urn
WHERE 1 = 1
{"" if self.config.include_all_versions else "AND mav.version = 0"}
{"" if not self.config.exclude_aspects else "AND mav.aspect NOT IN %(exclude_aspects)s"}
AND mav.createdon >= %(since_createdon)s
ORDER BY
createdon,
urn,
aspect,
version
) as t
WHERE 1=1
{"" if self.config.include_soft_deleted_entities else "AND (removed = false or removed is NULL)"}
ORDER BY
createdon,
urn,
aspect,
version
"""

def get_aspects(
self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]:
orderer = VersionOrderer[Row](enabled=self.config.include_all_versions)
orderer = VersionOrderer[Dict[str, Any]](
enabled=self.config.include_all_versions
)
rows = self._get_rows(from_createdon=from_createdon, stop_time=stop_time)
for row in orderer(rows):
mcp = self._parse_row(row)
if mcp:
yield mcp, row.createdon
yield mcp, row["createdon"]

def _get_rows(self, from_createdon: datetime, stop_time: datetime) -> Iterable[Row]:
def _get_rows(
self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Dict[str, Any]]:
with self.engine.connect() as conn:
ts = from_createdon
offset = 0
while ts.timestamp() <= stop_time.timestamp():
logger.debug(f"Polling database aspects from {ts}")
rows = conn.execute(
with contextlib.closing(conn.connection.cursor()) as cursor:
cursor.execute(
self.query,
since_createdon=ts.strftime(DATETIME_FORMAT),
limit=self.config.database_query_batch_size,
offset=offset,
{
"exclude_aspects": list(self.config.exclude_aspects),
"since_createdon": from_createdon.strftime(DATETIME_FORMAT),
},
)
if not rows.rowcount:
return

for i, row in enumerate(rows):
yield row
columns = [desc[0] for desc in cursor.description]
while True:
rows = cursor.fetchmany(self.config.database_query_batch_size)
if not rows:
return
for row in rows:
yield dict(zip(columns, row))

if ts == row.createdon:
offset += i + 1
else:
ts = row.createdon
offset = 0
def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]:
"""
Fetches all soft-deleted entities from the database.
Yields:
Row objects containing URNs of soft-deleted entities
"""
with self.engine.connect() as conn:
with contextlib.closing(conn.connection.cursor()) as cursor:
logger.debug("Polling soft-deleted urns from database")
cursor.execute(self.soft_deleted_urns_query)
columns = [desc[0] for desc in cursor.description]
while True:
rows = cursor.fetchmany(self.config.database_query_batch_size)
if not rows:
return
for row in rows:
yield dict(zip(columns, row))

def _parse_row(self, row: Row) -> Optional[MetadataChangeProposalWrapper]:
def _parse_row(
self, row: Dict[str, Any]
) -> Optional[MetadataChangeProposalWrapper]:
try:
json_aspect = post_json_transform(json.loads(row.metadata))
json_metadata = post_json_transform(json.loads(row.systemmetadata or "{}"))
json_aspect = post_json_transform(json.loads(row["metadata"]))
json_metadata = post_json_transform(
json.loads(row["systemmetadata"] or "{}")
)
system_metadata = SystemMetadataClass.from_obj(json_metadata)
return MetadataChangeProposalWrapper(
entityUrn=row.urn,
aspect=ASPECT_MAP[row.aspect].from_obj(json_aspect),
entityUrn=row["urn"],
aspect=ASPECT_MAP[row["aspect"]].from_obj(json_aspect),
systemMetadata=system_metadata,
changeType=ChangeTypeClass.UPSERT,
)
except Exception as e:
logger.warning(
f"Failed to parse metadata for {row.urn}: {e}", exc_info=True
f'Failed to parse metadata for {row["urn"]}: {e}', exc_info=True
)
self.report.num_database_parse_errors += 1
self.report.database_parse_errors.setdefault(
str(e), LossyDict()
).setdefault(row.aspect, LossyList()).append(row.urn)
).setdefault(row["aspect"], LossyList()).append(row["urn"])
return None
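Because the database reader now yields plain dict rows instead of SQLAlchemy Row objects, VersionOrderer is parameterized over Dict[str, Any] as well. Below is a small, self-contained sketch of the ordering behavior; the import path is assumed to be datahub.ingestion.source.datahub.datahub_database_reader, the rows are invented, and createdon is shown as an int purely for brevity.

from typing import Any, Dict

from datahub.ingestion.source.datahub.datahub_database_reader import (  # assumed module path
    VersionOrderer,
)

# Invented rows: only the "version" and "createdon" keys matter to the orderer.
rows = [
    {"urn": "urn:li:corpuser:a", "version": 1, "createdon": 1},
    {"urn": "urn:li:corpuser:a", "version": 0, "createdon": 1},
    {"urn": "urn:li:corpuser:b", "version": 0, "createdon": 2},
]

orderer = VersionOrderer[Dict[str, Any]](enabled=True)
for row in orderer(rows):
    # Version-0 rows are held back until createdon advances, so they come out
    # after any non-zero versions that share the same createdon.
    print(row["urn"], row["version"], row["createdon"])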
@@ -36,6 +36,7 @@ def __init__(
self.connection_config = connection_config
self.report = report
self.group_id = f"{KAFKA_GROUP_PREFIX}-{ctx.pipeline_name}"
self.ctx = ctx

def __enter__(self) -> "DataHubKafkaReader":
self.consumer = DeserializingConsumer(
@@ -95,6 +96,10 @@ def _poll_partition(
)
break

if mcl.aspectName and mcl.aspectName in self.config.exclude_aspects:
self.report.num_kafka_excluded_aspects += 1
continue

# TODO: Consider storing state in kafka instead, via consumer.commit()
yield mcl, PartitionOffset(partition=msg.partition(), offset=msg.offset())

@@ -62,13 +62,18 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.report.stop_time = datetime.now(tz=timezone.utc)
logger.info(f"Ingesting DataHub metadata up until {self.report.stop_time}")
state = self.stateful_ingestion_handler.get_last_run_state()
database_reader: Optional[DataHubDatabaseReader] = None

if self.config.pull_from_datahub_api:
yield from self._get_api_workunits()

if self.config.database_connection is not None:
database_reader = DataHubDatabaseReader(
self.config, self.config.database_connection, self.report
)

yield from self._get_database_workunits(
from_createdon=state.database_createdon_datetime
from_createdon=state.database_createdon_datetime, reader=database_reader
)
self._commit_progress()
else:
@@ -77,23 +82,29 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)

if self.config.kafka_connection is not None:
yield from self._get_kafka_workunits(from_offsets=state.kafka_offsets)
soft_deleted_urns = []
if not self.config.include_soft_deleted_entities:
if database_reader is None:
raise ValueError(
"Cannot exclude soft deleted entities without a database connection"
)
soft_deleted_urns = [
row["urn"] for row in database_reader.get_soft_deleted_rows()
]

yield from self._get_kafka_workunits(
from_offsets=state.kafka_offsets, soft_deleted_urns=soft_deleted_urns
)
self._commit_progress()
else:
logger.info(
"Skipping ingestion of timeseries aspects as no kafka_connection provided"
)

def _get_database_workunits(
self, from_createdon: datetime
self, from_createdon: datetime, reader: DataHubDatabaseReader
) -> Iterable[MetadataWorkUnit]:
if self.config.database_connection is None:
return

logger.info(f"Fetching database aspects starting from {from_createdon}")
reader = DataHubDatabaseReader(
self.config, self.config.database_connection, self.report
)
mcps = reader.get_aspects(from_createdon, self.report.stop_time)
for i, (mcp, createdon) in enumerate(mcps):

@@ -113,20 +124,29 @@ def _get_kafka_workunits(
self._commit_progress(i)

def _get_kafka_workunits(
self, from_offsets: Dict[int, int]
self, from_offsets: Dict[int, int], soft_deleted_urns: List[str] = []
) -> Iterable[MetadataWorkUnit]:
if self.config.kafka_connection is None:
return

logger.info("Fetching timeseries aspects from kafka")
with DataHubKafkaReader(
self.config, self.config.kafka_connection, self.report, self.ctx
self.config,
self.config.kafka_connection,
self.report,
self.ctx,
) as reader:
mcls = reader.get_mcls(
from_offsets=from_offsets, stop_time=self.report.stop_time
)
for i, (mcl, offset) in enumerate(mcls):
mcp = MetadataChangeProposalWrapper.try_from_mcl(mcl)
if mcp.entityUrn in soft_deleted_urns:
self.report.num_timeseries_soft_deleted_aspects_dropped += 1
logger.debug(
f"Dropping soft-deleted aspect of {mcp.aspectName} on {mcp.entityUrn}"
)
continue
if mcp.changeType == ChangeTypeClass.DELETE:
self.report.num_timeseries_deletions_dropped += 1
logger.debug(
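Taken together: when include_soft_deleted_entities is disabled, the source builds the database reader once, asks it (via soft_deleted_urns_query, which joins each row against its version-0 status aspect and keeps urns whose JSON metadata has removed = true) for the set of soft-deleted urns, and then drops timeseries aspects read from Kafka for those urns; doing this without a database_connection raises the ValueError shown above. A compact restatement of that guard with placeholder names, not the source's actual code:

from typing import Callable, Iterable, List

def soft_deleted_urns_to_filter(
    include_soft_deleted_entities: bool,
    have_database_connection: bool,
    fetch_soft_deleted_urns: Callable[[], Iterable[str]],
) -> List[str]:
    # Placeholder helper mirroring the guard added in get_workunits_internal.
    if include_soft_deleted_entities:
        return []  # nothing needs to be filtered from the Kafka stream
    if not have_database_connection:
        raise ValueError(
            "Cannot exclude soft deleted entities without a database connection"
        )
    return list(fetch_soft_deleted_urns())

# Hypothetical usage; the lambda stands in for DataHubDatabaseReader.get_soft_deleted_rows().
urns = soft_deleted_urns_to_filter(
    include_soft_deleted_entities=False,
    have_database_connection=True,
    fetch_soft_deleted_urns=lambda: ["urn:li:corpuser:deleted_user"],
)
print(urns)  # ['urn:li:corpuser:deleted_user']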