From b33ad0a788b837783411f1539173d91071453fde Mon Sep 17 00:00:00 2001
From: Tamas Nemeth
Date: Wed, 30 Oct 2024 17:41:45 +0100
Subject: [PATCH] feat(ingest/datahub): Add way to filter soft deleted entities
 (#11738)

---
 .../ingestion/source/datahub/config.py        |  15 +-
 .../source/datahub/datahub_api_reader.py      |  13 +-
 .../source/datahub/datahub_database_reader.py | 153 ++++++++++++------
 .../source/datahub/datahub_kafka_reader.py    |   5 +
 .../source/datahub/datahub_source.py          |  42 +++--
 .../ingestion/source/datahub/report.py        |   2 +
 .../tests/unit/test_datahub_source.py         |  52 +++---
 7 files changed, 190 insertions(+), 92 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py
index 9705d63912b8d..a3304334cb1eb 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Set
 
 from pydantic import Field, root_validator
 
@@ -35,6 +35,19 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
         ),
     )
 
+    include_soft_deleted_entities: bool = Field(
+        default=True,
+        description=(
+            "If enabled, include all entities, regardless of their soft-deletion status. "
+            "Otherwise, exclude entities that have been soft deleted. "
+        ),
+    )
+
+    exclude_aspects: Set[str] = Field(
+        default_factory=set,
+        description="Set of aspect names to exclude from ingestion",
+    )
+
     database_query_batch_size: int = Field(
         default=DEFAULT_DATABASE_BATCH_SIZE,
         description="Number of records to fetch from the database at a time",
diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py
index 6986aac0a7757..382a0d548e38d 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_api_reader.py
@@ -26,11 +26,17 @@ def __init__(
         self.report = report
         self.graph = graph
 
-    def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
+    def get_urns(self) -> Iterable[str]:
         urns = self.graph.get_urns_by_filter(
-            status=RemovedStatusFilter.ALL,
+            status=RemovedStatusFilter.ALL
+            if self.config.include_soft_deleted_entities
+            else RemovedStatusFilter.NOT_SOFT_DELETED,
             batch_size=self.config.database_query_batch_size,
         )
+        return urns
+
+    def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
+        urns = self.get_urns()
         tasks: List[futures.Future[Iterable[MetadataChangeProposalWrapper]]] = []
         with futures.ThreadPoolExecutor(
             max_workers=self.config.max_workers
@@ -43,6 +49,9 @@ def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
     def _get_aspects_for_urn(self, urn: str) -> Iterable[MetadataChangeProposalWrapper]:
         aspects: Dict[str, _Aspect] = self.graph.get_entity_semityped(urn)  # type: ignore
         for aspect in aspects.values():
+            if aspect.get_aspect_name().lower() in self.config.exclude_aspects:
+                continue
+
             yield MetadataChangeProposalWrapper(
                 entityUrn=urn,
                 aspect=aspect,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py
index e4f1bb275487e..faa281097de4c 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py
+++ 
b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_database_reader.py
@@ -1,11 +1,10 @@
+import contextlib
 import json
 import logging
 from datetime import datetime
-from typing import Any, Generic, Iterable, List, Optional, Tuple, TypeVar
+from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar
 
 from sqlalchemy import create_engine
-from sqlalchemy.engine import Row
-from typing_extensions import Protocol
 
 from datahub.emitter.aspect import ASPECT_MAP
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -21,13 +20,7 @@
 # Should work for at least mysql, mariadb, postgres
 DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
 
-
-class VersionOrderable(Protocol):
-    createdon: Any  # Should restrict to only orderable types
-    version: int
-
-
-ROW = TypeVar("ROW", bound=VersionOrderable)
+ROW = TypeVar("ROW", bound=Dict[str, Any])
 
 
 class VersionOrderer(Generic[ROW]):
@@ -54,14 +47,14 @@ def _process_row(self, row: ROW) -> Iterable[ROW]:
             return
 
         yield from self._attempt_queue_flush(row)
-        if row.version == 0:
+        if row["version"] == 0:
             self._add_to_queue(row)
         else:
             yield row
 
     def _add_to_queue(self, row: ROW) -> None:
         if self.queue is None:
-            self.queue = (row.createdon, [row])
+            self.queue = (row["createdon"], [row])
         else:
             self.queue[1].append(row)
 
@@ -69,7 +62,7 @@ def _attempt_queue_flush(self, row: ROW) -> Iterable[ROW]:
         if self.queue is None:
             return
 
-        if row.createdon > self.queue[0]:
+        if row["createdon"] > self.queue[0]:
             yield from self._flush_queue()
 
     def _flush_queue(self) -> Iterable[ROW]:
@@ -92,6 +85,21 @@ def __init__(
             **connection_config.options,
         )
 
+    @property
+    def soft_deleted_urns_query(self) -> str:
+        return f"""
+            SELECT DISTINCT mav.urn
+            FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
+            JOIN (
+                SELECT *,
+                       JSON_EXTRACT(metadata, '$.removed') as removed
+                FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
+                WHERE aspect = 'status' AND version = 0
+            ) as sd ON sd.urn = mav.urn
+            WHERE sd.removed = true
+            ORDER BY mav.urn
+        """
+
     @property
     def query(self) -> str:
         # May repeat rows for the same date
@@ -101,66 +109,117 @@ def query(self) -> str:
         # Relies on createdon order to reflect version order
         # Ordering of entries with the same createdon is handled by VersionOrderer
         return f"""
-            SELECT urn, aspect, metadata, systemmetadata, createdon, version
-            FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
-            WHERE createdon >= %(since_createdon)s
-            {"" if self.config.include_all_versions else "AND version = 0"}
-            ORDER BY createdon, urn, aspect, version
-            LIMIT %(limit)s
-            OFFSET %(offset)s
+            SELECT *
+            FROM (
+                SELECT
+                    mav.urn,
+                    mav.aspect,
+                    mav.metadata,
+                    mav.systemmetadata,
+                    mav.createdon,
+                    mav.version,
+                    removed
+                FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
+                LEFT JOIN (
+                    SELECT
+                        *,
+                        JSON_EXTRACT(metadata, '$.removed') as removed
+                    FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
+                    WHERE aspect = 'status'
+                    AND version = 0
+                ) as sd ON sd.urn = mav.urn
+                WHERE 1 = 1
+                {"" if self.config.include_all_versions else "AND mav.version = 0"}
+                {"" if not self.config.exclude_aspects else "AND mav.aspect NOT IN %(exclude_aspects)s"}
+                AND mav.createdon >= %(since_createdon)s
+                ORDER BY
+                    createdon,
+                    urn,
+                    aspect,
+                    version
+            ) as t
+            WHERE 1 = 1
+            {"" if self.config.include_soft_deleted_entities else "AND (removed = false OR removed IS NULL)"}
+            ORDER BY 
+ createdon, + urn, + aspect, + version """ def get_aspects( self, from_createdon: datetime, stop_time: datetime ) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]: - orderer = VersionOrderer[Row](enabled=self.config.include_all_versions) + orderer = VersionOrderer[Dict[str, Any]]( + enabled=self.config.include_all_versions + ) rows = self._get_rows(from_createdon=from_createdon, stop_time=stop_time) for row in orderer(rows): mcp = self._parse_row(row) if mcp: - yield mcp, row.createdon + yield mcp, row["createdon"] - def _get_rows(self, from_createdon: datetime, stop_time: datetime) -> Iterable[Row]: + def _get_rows( + self, from_createdon: datetime, stop_time: datetime + ) -> Iterable[Dict[str, Any]]: with self.engine.connect() as conn: - ts = from_createdon - offset = 0 - while ts.timestamp() <= stop_time.timestamp(): - logger.debug(f"Polling database aspects from {ts}") - rows = conn.execute( + with contextlib.closing(conn.connection.cursor()) as cursor: + cursor.execute( self.query, - since_createdon=ts.strftime(DATETIME_FORMAT), - limit=self.config.database_query_batch_size, - offset=offset, + { + "exclude_aspects": list(self.config.exclude_aspects), + "since_createdon": from_createdon.strftime(DATETIME_FORMAT), + }, ) - if not rows.rowcount: - return - for i, row in enumerate(rows): - yield row + columns = [desc[0] for desc in cursor.description] + while True: + rows = cursor.fetchmany(self.config.database_query_batch_size) + if not rows: + return + for row in rows: + yield dict(zip(columns, row)) - if ts == row.createdon: - offset += i + 1 - else: - ts = row.createdon - offset = 0 + def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]: + """ + Fetches all soft-deleted entities from the database. - def _parse_row(self, row: Row) -> Optional[MetadataChangeProposalWrapper]: + Yields: + Row objects containing URNs of soft-deleted entities + """ + with self.engine.connect() as conn: + with contextlib.closing(conn.connection.cursor()) as cursor: + logger.debug("Polling soft-deleted urns from database") + cursor.execute(self.soft_deleted_urns_query) + columns = [desc[0] for desc in cursor.description] + while True: + rows = cursor.fetchmany(self.config.database_query_batch_size) + if not rows: + return + for row in rows: + yield dict(zip(columns, row)) + + def _parse_row( + self, row: Dict[str, Any] + ) -> Optional[MetadataChangeProposalWrapper]: try: - json_aspect = post_json_transform(json.loads(row.metadata)) - json_metadata = post_json_transform(json.loads(row.systemmetadata or "{}")) + json_aspect = post_json_transform(json.loads(row["metadata"])) + json_metadata = post_json_transform( + json.loads(row["systemmetadata"] or "{}") + ) system_metadata = SystemMetadataClass.from_obj(json_metadata) return MetadataChangeProposalWrapper( - entityUrn=row.urn, - aspect=ASPECT_MAP[row.aspect].from_obj(json_aspect), + entityUrn=row["urn"], + aspect=ASPECT_MAP[row["aspect"]].from_obj(json_aspect), systemMetadata=system_metadata, changeType=ChangeTypeClass.UPSERT, ) except Exception as e: logger.warning( - f"Failed to parse metadata for {row.urn}: {e}", exc_info=True + f'Failed to parse metadata for {row["urn"]}: {e}', exc_info=True ) self.report.num_database_parse_errors += 1 self.report.database_parse_errors.setdefault( str(e), LossyDict() - ).setdefault(row.aspect, LossyList()).append(row.urn) + ).setdefault(row["aspect"], LossyList()).append(row["urn"]) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py 
b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py index d9e53e87c2cea..56a3d55abb184 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_kafka_reader.py @@ -36,6 +36,7 @@ def __init__( self.connection_config = connection_config self.report = report self.group_id = f"{KAFKA_GROUP_PREFIX}-{ctx.pipeline_name}" + self.ctx = ctx def __enter__(self) -> "DataHubKafkaReader": self.consumer = DeserializingConsumer( @@ -95,6 +96,10 @@ def _poll_partition( ) break + if mcl.aspectName and mcl.aspectName in self.config.exclude_aspects: + self.report.num_kafka_excluded_aspects += 1 + continue + # TODO: Consider storing state in kafka instead, via consumer.commit() yield mcl, PartitionOffset(partition=msg.partition(), offset=msg.offset()) diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py index 0204a864e2b9e..de212ca9a6771 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py @@ -62,13 +62,18 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.report.stop_time = datetime.now(tz=timezone.utc) logger.info(f"Ingesting DataHub metadata up until {self.report.stop_time}") state = self.stateful_ingestion_handler.get_last_run_state() + database_reader: Optional[DataHubDatabaseReader] = None if self.config.pull_from_datahub_api: yield from self._get_api_workunits() if self.config.database_connection is not None: + database_reader = DataHubDatabaseReader( + self.config, self.config.database_connection, self.report + ) + yield from self._get_database_workunits( - from_createdon=state.database_createdon_datetime + from_createdon=state.database_createdon_datetime, reader=database_reader ) self._commit_progress() else: @@ -77,7 +82,19 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) if self.config.kafka_connection is not None: - yield from self._get_kafka_workunits(from_offsets=state.kafka_offsets) + soft_deleted_urns = [] + if not self.config.include_soft_deleted_entities: + if database_reader is None: + raise ValueError( + "Cannot exclude soft deleted entities without a database connection" + ) + soft_deleted_urns = [ + row["urn"] for row in database_reader.get_soft_deleted_rows() + ] + + yield from self._get_kafka_workunits( + from_offsets=state.kafka_offsets, soft_deleted_urns=soft_deleted_urns + ) self._commit_progress() else: logger.info( @@ -85,15 +102,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ) def _get_database_workunits( - self, from_createdon: datetime + self, from_createdon: datetime, reader: DataHubDatabaseReader ) -> Iterable[MetadataWorkUnit]: - if self.config.database_connection is None: - return - logger.info(f"Fetching database aspects starting from {from_createdon}") - reader = DataHubDatabaseReader( - self.config, self.config.database_connection, self.report - ) mcps = reader.get_aspects(from_createdon, self.report.stop_time) for i, (mcp, createdon) in enumerate(mcps): @@ -113,20 +124,29 @@ def _get_database_workunits( self._commit_progress(i) def _get_kafka_workunits( - self, from_offsets: Dict[int, int] + self, from_offsets: Dict[int, int], soft_deleted_urns: List[str] = [] ) -> Iterable[MetadataWorkUnit]: if self.config.kafka_connection is None: return 
logger.info("Fetching timeseries aspects from kafka") with DataHubKafkaReader( - self.config, self.config.kafka_connection, self.report, self.ctx + self.config, + self.config.kafka_connection, + self.report, + self.ctx, ) as reader: mcls = reader.get_mcls( from_offsets=from_offsets, stop_time=self.report.stop_time ) for i, (mcl, offset) in enumerate(mcls): mcp = MetadataChangeProposalWrapper.try_from_mcl(mcl) + if mcp.entityUrn in soft_deleted_urns: + self.report.num_timeseries_soft_deleted_aspects_dropped += 1 + logger.debug( + f"Dropping soft-deleted aspect of {mcp.aspectName} on {mcp.entityUrn}" + ) + continue if mcp.changeType == ChangeTypeClass.DELETE: self.report.num_timeseries_deletions_dropped += 1 logger.debug( diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py index 73e5a798a1553..721fc87989442 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/report.py @@ -20,6 +20,8 @@ class DataHubSourceReport(StatefulIngestionReport): num_kafka_aspects_ingested: int = 0 num_kafka_parse_errors: int = 0 + num_kafka_excluded_aspects: int = 0 kafka_parse_errors: LossyDict[str, int] = field(default_factory=LossyDict) num_timeseries_deletions_dropped: int = 0 + num_timeseries_soft_deleted_aspects_dropped: int = 0 diff --git a/metadata-ingestion/tests/unit/test_datahub_source.py b/metadata-ingestion/tests/unit/test_datahub_source.py index adc131362b326..67b2b85d9af6d 100644 --- a/metadata-ingestion/tests/unit/test_datahub_source.py +++ b/metadata-ingestion/tests/unit/test_datahub_source.py @@ -1,51 +1,41 @@ -from dataclasses import dataclass +from typing import Any, Dict import pytest -from datahub.ingestion.source.datahub.datahub_database_reader import ( - VersionOrderable, - VersionOrderer, -) - - -@dataclass -class MockRow(VersionOrderable): - createdon: int - version: int - urn: str +from datahub.ingestion.source.datahub.datahub_database_reader import VersionOrderer @pytest.fixture def rows(): return [ - MockRow(0, 0, "one"), - MockRow(0, 1, "one"), - MockRow(0, 0, "two"), - MockRow(0, 0, "three"), - MockRow(0, 1, "three"), - MockRow(0, 2, "three"), - MockRow(0, 1, "two"), - MockRow(0, 4, "three"), - MockRow(0, 5, "three"), - MockRow(1, 6, "three"), - MockRow(1, 0, "four"), - MockRow(2, 0, "five"), - MockRow(2, 1, "six"), - MockRow(2, 0, "six"), - MockRow(3, 0, "seven"), - MockRow(3, 0, "eight"), + {"createdon": 0, "version": 0, "urn": "one"}, + {"createdon": 0, "version": 1, "urn": "one"}, + {"createdon": 0, "version": 0, "urn": "two"}, + {"createdon": 0, "version": 0, "urn": "three"}, + {"createdon": 0, "version": 1, "urn": "three"}, + {"createdon": 0, "version": 2, "urn": "three"}, + {"createdon": 0, "version": 1, "urn": "two"}, + {"createdon": 0, "version": 4, "urn": "three"}, + {"createdon": 0, "version": 5, "urn": "three"}, + {"createdon": 1, "version": 6, "urn": "three"}, + {"createdon": 1, "version": 0, "urn": "four"}, + {"createdon": 2, "version": 0, "urn": "five"}, + {"createdon": 2, "version": 1, "urn": "six"}, + {"createdon": 2, "version": 0, "urn": "six"}, + {"createdon": 3, "version": 0, "urn": "seven"}, + {"createdon": 3, "version": 0, "urn": "eight"}, ] def test_version_orderer(rows): - orderer = VersionOrderer[MockRow](enabled=True) + orderer = VersionOrderer[Dict[str, Any]](enabled=True) ordered_rows = list(orderer(rows)) assert ordered_rows == sorted( - ordered_rows, key=lambda x: 
(x.createdon, x.version == 0) + ordered_rows, key=lambda x: (x["createdon"], x["version"] == 0) ) def test_version_orderer_disabled(rows): - orderer = VersionOrderer[MockRow](enabled=False) + orderer = VersionOrderer[Dict[str, Any]](enabled=False) ordered_rows = list(orderer(rows)) assert ordered_rows == rows
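
The two options this patch introduces, include_soft_deleted_entities and exclude_aspects, are ordinary source-config fields, so they can be exercised from a recipe-style dict. Below is a minimal sketch of such a config; everything other than the two new fields (the source type key, the connection values, and the example aspect names) is an illustrative assumption, not something defined by this patch.

# Minimal sketch of a `datahub` source config using the new options.
# Only include_soft_deleted_entities and exclude_aspects come from this patch;
# the connection settings and aspect names below are placeholder assumptions.
datahub_source_recipe = {
    "type": "datahub",
    "config": {
        # When False, soft-deleted entities are filtered out of the database
        # and API reads, and their timeseries aspects are dropped from the
        # Kafka read using the URNs returned by get_soft_deleted_rows().
        "include_soft_deleted_entities": False,
        # Aspect names skipped by the database, API, and Kafka readers
        # (the API reader lowercases the name before comparing; the other
        # two match it exactly).
        "exclude_aspects": ["datasetProfile", "datasetUsageStatistics"],
        # A database connection is required when include_soft_deleted_entities
        # is False and Kafka ingestion is enabled (see datahub_source.py above).
        "database_connection": {
            "scheme": "mysql+pymysql",
            "host_port": "localhost:3306",
            "username": "datahub",
            "password": "datahub",
            "database": "datahub",
        },
    },
}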