Skip to content

Commit

Permalink
perf(ingest): streamline CLL generation (datahub-project#11645)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and aviv-julienjehannet committed Oct 21, 2024
1 parent f111501 commit aed3aa2
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 28 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
"acryl-sqlglot[rs]==25.20.2.dev6",
"acryl-sqlglot[rs]==25.25.2.dev9",
}

classification_lib = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def parse_alter_table_rename(default_schema: str, query: str) -> Tuple[str, str,
assert isinstance(parsed_query, sqlglot.exp.Alter)
prev_name = parsed_query.this.name
rename_clause = parsed_query.args["actions"][0]
assert isinstance(rename_clause, sqlglot.exp.RenameTable)
assert isinstance(rename_clause, sqlglot.exp.AlterRename)
new_name = rename_clause.this.name

schema = parsed_query.this.db or default_schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2131,7 +2131,7 @@ def _create_lineage_from_unsupported_csql(

fine_grained_lineages: List[FineGrainedLineage] = []
if self.config.extract_column_level_lineage:
logger.info("Extracting CLL from custom sql")
logger.debug("Extracting CLL from custom sql")
fine_grained_lineages = make_fine_grained_lineage_class(
parsed_result, csql_urn, out_columns
)
Expand Down
51 changes: 32 additions & 19 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import functools
import itertools
import logging
import traceback
from collections import defaultdict
Expand All @@ -14,6 +13,8 @@
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.optimizer
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
import sqlglot.optimizer.unnest_subqueries

from datahub.cli.env_utils import get_boolean_env_variable
from datahub.ingestion.graph.client import DataHubGraph
Expand Down Expand Up @@ -63,24 +64,30 @@
SQL_LINEAGE_TIMEOUT_SECONDS = 10


RULES_BEFORE_TYPE_ANNOTATION: tuple = tuple(
filter(
lambda func: func.__name__
not in {
# Skip pushdown_predicates because it sometimes throws exceptions, and we
# don't actually need it for anything.
"pushdown_predicates",
# Skip normalize because it can sometimes be expensive.
"normalize",
},
itertools.takewhile(
lambda func: func != sqlglot.optimizer.annotate_types.annotate_types,
sqlglot.optimizer.optimizer.RULES,
),
)
# These rules are a subset of the rules in sqlglot.optimizer.optimizer.RULES.
# If there's a change in their rules, we probably need to re-evaluate our list as well.
assert len(sqlglot.optimizer.optimizer.RULES) == 14

_OPTIMIZE_RULES = (
sqlglot.optimizer.optimizer.qualify,
# We need to enable this in order for annotate types to work.
sqlglot.optimizer.optimizer.pushdown_projections,
# sqlglot.optimizer.optimizer.normalize, # causes perf issues
sqlglot.optimizer.optimizer.unnest_subqueries,
# sqlglot.optimizer.optimizer.pushdown_predicates, # causes perf issues
# sqlglot.optimizer.optimizer.optimize_joins,
# sqlglot.optimizer.optimizer.eliminate_subqueries,
# sqlglot.optimizer.optimizer.merge_subqueries,
# sqlglot.optimizer.optimizer.eliminate_joins,
# sqlglot.optimizer.optimizer.eliminate_ctes,
sqlglot.optimizer.optimizer.quote_identifiers,
# These three are run separately or not run at all.
# sqlglot.optimizer.optimizer.annotate_types,
# sqlglot.optimizer.canonicalize.canonicalize,
# sqlglot.optimizer.simplify.simplify,
)
# Quick check that the rules were loaded correctly.
assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES)

_DEBUG_TYPE_ANNOTATIONS = False


class _ColumnRef(_FrozenModel):
Expand Down Expand Up @@ -385,11 +392,12 @@ def _sqlglot_force_column_normalizer(
schema=sqlglot_db_schema,
qualify_columns=True,
validate_qualify_columns=False,
allow_partial_qualification=True,
identify=True,
# sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
catalog=default_db,
db=default_schema,
rules=RULES_BEFORE_TYPE_ANNOTATION,
rules=_OPTIMIZE_RULES,
)
except (sqlglot.errors.OptimizeError, ValueError) as e:
raise SqlUnderstandingError(
Expand All @@ -408,6 +416,10 @@ def _sqlglot_force_column_normalizer(
except (sqlglot.errors.OptimizeError, sqlglot.errors.ParseError) as e:
# This is not a fatal error, so we can continue.
logger.debug("sqlglot failed to annotate or parse types: %s", e)
if _DEBUG_TYPE_ANNOTATIONS and logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Type annotated sql %s", statement.sql(pretty=True, dialect=dialect)
)

return statement, _ColumnResolver(
sqlglot_db_schema=sqlglot_db_schema,
Expand Down Expand Up @@ -907,6 +919,7 @@ def _sqlglot_lineage_inner(
# At this stage we only want to qualify the table names. The columns will be dealt with later.
qualify_columns=False,
validate_qualify_columns=False,
allow_partial_qualification=True,
# Only insert quotes where necessary.
identify=False,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from datahub.configuration.source_common import DEFAULT_ENV
from datahub.emitter.mce_builder import make_schema_field_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.run.pipeline import Pipeline, PipelineContext
from datahub.ingestion.source.tableau.tableau import (
TableauConfig,
Expand All @@ -37,7 +38,7 @@
FineGrainedLineageUpstreamType,
UpstreamLineage,
)
from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass
from datahub.metadata.schema_classes import UpstreamClass
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.state_helpers import (
get_current_checkpoint_from_pipeline,
Expand Down Expand Up @@ -939,11 +940,12 @@ def test_tableau_unsupported_csql():
database_override_map={"production database": "prod"}
)

def test_lineage_metadata(
def check_lineage_metadata(
lineage, expected_entity_urn, expected_upstream_table, expected_cll
):
mcp = cast(MetadataChangeProposalClass, next(iter(lineage)).metadata)
assert mcp.aspect == UpstreamLineage(
mcp = cast(MetadataChangeProposalWrapper, list(lineage)[0].metadata)

expected = UpstreamLineage(
upstreams=[
UpstreamClass(
dataset=expected_upstream_table,
Expand All @@ -966,6 +968,9 @@ def test_lineage_metadata(
)
assert mcp.entityUrn == expected_entity_urn

actual_aspect = mcp.aspect
assert actual_aspect == expected

csql_urn = "urn:li:dataset:(urn:li:dataPlatform:tableau,09988088-05ad-173c-a2f1-f33ba3a13d1a,PROD)"
expected_upstream_table = "urn:li:dataset:(urn:li:dataPlatform:bigquery,my_bigquery_project.invent_dw.UserDetail,PROD)"
expected_cll = {
Expand Down Expand Up @@ -996,7 +1001,7 @@ def test_lineage_metadata(
},
out_columns=[],
)
test_lineage_metadata(
check_lineage_metadata(
lineage=lineage,
expected_entity_urn=csql_urn,
expected_upstream_table=expected_upstream_table,
Expand All @@ -1014,7 +1019,7 @@ def test_lineage_metadata(
},
out_columns=[],
)
test_lineage_metadata(
check_lineage_metadata(
lineage=lineage,
expected_entity_urn=csql_urn,
expected_upstream_table=expected_upstream_table,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"query_type": "SELECT",
"query_type_props": {},
"query_fingerprint": "4094ebd230c1d47c7e6879b05ab927e550923b1986eb58c5f3814396cf401d18",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "user_id",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
"column": "user_id"
}
]
},
{
"downstream": {
"table": null,
"column": "source",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
"column": "source"
}
]
},
{
"downstream": {
"table": null,
"column": "user_source",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,invent_dw.UserDetail,PROD)",
"column": "user_source"
}
]
}
],
"debug_info": {
"confidence": 0.2,
"generalized_statement": "SELECT user_id, source, user_source FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY __partition_day DESC) AS rank_ FROM invent_dw.UserDetail) AS source_user WHERE rank_ = ?"
}
}
15 changes: 15 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,3 +1253,18 @@ def test_snowflake_drop_schema() -> None:
dialect="snowflake",
expected_file=RESOURCE_DIR / "test_snowflake_drop_schema.json",
)


def test_bigquery_subquery_column_inference() -> None:
assert_sql_result(
"""\
SELECT user_id, source, user_source
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY __partition_day DESC) AS rank_
FROM invent_dw.UserDetail
) source_user
WHERE rank_ = 1
""",
dialect="bigquery",
expected_file=RESOURCE_DIR / "test_bigquery_subquery_column_inference.json",
)

0 comments on commit aed3aa2

Please sign in to comment.