
Commit

Merge branch 'datahub-project:master' into master
anshbansal authored May 23, 2024
2 parents d136e7d + e361d28 commit 8c8127f
Showing 12 changed files with 957 additions and 74 deletions.
@@ -63,7 +63,7 @@ type Props = {
export const ViewDropdownMenu = ({
view,
visible,
isOwnedByUser,
isOwnedByUser = view.viewType === DataHubViewType.Personal,
trigger = 'hover',
onClickEdit,
onClickPreview,
@@ -23,7 +23,8 @@ private UrnUtils() {}
@Nonnull
public static DatasetUrn toDatasetUrn(
@Nonnull String platformName, @Nonnull String datasetName, @Nonnull String origin) {
return new DatasetUrn(new DataPlatformUrn(platformName), datasetName, FabricType.valueOf(origin.toUpperCase()));
return new DatasetUrn(
new DataPlatformUrn(platformName), datasetName, FabricType.valueOf(origin.toUpperCase()));
}

/**
metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py (142 changes: 86 additions & 56 deletions)
@@ -142,11 +142,17 @@

@dataclass
class DBTSourceReport(StaleEntityRemovalSourceReport):
sql_statements_parsed: int = 0
sql_statements_table_error: int = 0
sql_statements_column_error: int = 0
sql_parser_detach_ctes_failures: LossyList[str] = field(default_factory=LossyList)
sql_parser_skipped_missing_code: LossyList[str] = field(default_factory=LossyList)
sql_parser_parse_failures: int = 0
sql_parser_detach_ctes_failures: int = 0
sql_parser_table_errors: int = 0
sql_parser_column_errors: int = 0
sql_parser_successes: int = 0

sql_parser_parse_failures_list: LossyList[str] = field(default_factory=LossyList)
sql_parser_detach_ctes_failures_list: LossyList[str] = field(
default_factory=LossyList
)

in_manifest_but_missing_catalog: LossyList[str] = field(default_factory=LossyList)
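
Note on the report refactor above: each SQL-parsing failure mode now gets both an aggregate counter and a size-bounded LossyList of example node names, so reports stay compact on large dbt projects. A minimal sketch of that pattern (not the actual DBTSourceReport; the LossyList import path is the one this diff already uses, the report name and helper method are made up):

from dataclasses import dataclass, field

from datahub.utilities.lossy_collections import LossyList


@dataclass
class MiniReport:
    # Aggregate count plus a bounded sample of offending dbt node names.
    sql_parser_parse_failures: int = 0
    sql_parser_parse_failures_list: LossyList[str] = field(default_factory=LossyList)

    def report_parse_failure(self, dbt_name: str) -> None:
        self.sql_parser_parse_failures += 1
        self.sql_parser_parse_failures_list.append(dbt_name)


report = MiniReport()
report.report_parse_failure("model.jaffle_shop.orders")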

@@ -558,10 +564,11 @@ def get_fake_ephemeral_table_name(self) -> str:
assert self.is_ephemeral_model()

# Similar to get_db_fqn.
fqn = self._join_parts(
db_fqn = self._join_parts(
[self.database, self.schema, f"__datahub__dbt__ephemeral__{self.name}"]
)
return fqn.replace('"', "")
db_fqn = db_fqn.lower()
return db_fqn.replace('"', "")

def get_urn_for_upstream_lineage(
self,
@@ -819,9 +826,10 @@ def get_column_type(

# if still not found, report the warning
if TypeClass is None:
report.report_warning(
dataset_name, f"unable to map type {column_type} to metadata schema"
)
if column_type:
report.report_warning(
dataset_name, f"unable to map type {column_type} to metadata schema"
)
TypeClass = NullTypeClass

return SchemaFieldDataType(type=TypeClass())
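
A standalone sketch of the get_column_type change above: an empty or missing dbt column type no longer produces an "unable to map type" warning; it simply falls back to the null type. The mapping table and warning list below are stand-ins for illustration, not DataHub's actual type registry:

# Toy type resolver demonstrating "warn only when there was actually a type to map".
TYPE_MAP = {"varchar": "StringType", "integer": "NumberType"}
warnings = []


def resolve_type(column_type: str) -> str:
    mapped = TYPE_MAP.get(column_type.lower()) if column_type else None
    if mapped is None:
        if column_type:  # only warn when a non-empty type failed to map
            warnings.append(f"unable to map type {column_type} to metadata schema")
        mapped = "NullType"
    return mapped


print(resolve_type(""), resolve_type("geometry"), warnings)
# NullType NullType ['unable to map type geometry to metadata schema']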
@@ -1041,15 +1049,16 @@ def _infer_schemas_and_update_cll( # noqa: C901

# Iterate over the dbt nodes in topological order.
# This ensures that we process upstream nodes before downstream nodes.
for dbt_name in topological_sort(
node_order = topological_sort(
list(all_nodes_map.keys()),
edges=list(
(upstream, node.dbt_name)
for node in all_nodes_map.values()
for upstream in node.upstream_nodes
if upstream in all_nodes_map
),
):
)
for dbt_name in node_order:
node = all_nodes_map[dbt_name]
logger.debug(f"Processing CLL/schemas for {node.dbt_name}")

@@ -1119,55 +1128,26 @@

# Run sql parser to infer the schema + generate column lineage.
sql_result = None
if node.node_type in {"source", "test"}:
if node.node_type in {"source", "test", "seed"}:
# For sources, we generate CLL as a 1:1 mapping.
# We don't support CLL for tests (assertions).
# We don't support CLL for tests (assertions) or seeds.
pass
elif node.compiled_code:
try:
# Add CTE stops based on the upstreams list.
cte_mapping = {
cte_name: upstream_node.get_fake_ephemeral_table_name()
for upstream_node in [
all_nodes_map[upstream_node_name]
for upstream_node_name in node.upstream_nodes
if upstream_node_name in all_nodes_map
]
if upstream_node.is_ephemeral_model()
for cte_name in _get_dbt_cte_names(
upstream_node.name, schema_resolver.platform
)
}
preprocessed_sql = detach_ctes(
parse_statements_and_pick(
node.compiled_code,
platform=schema_resolver.platform,
),
platform=schema_resolver.platform,
cte_mapping=cte_mapping,
)
except Exception as e:
self.report.sql_parser_detach_ctes_failures.append(node.dbt_name)
logger.debug(
f"Failed to detach CTEs from compiled code. {node.dbt_name} will not have column lineage."
)
sql_result = SqlParsingResult.make_from_error(e)
else:
sql_result = sqlglot_lineage(
preprocessed_sql, schema_resolver=schema_resolver
# Add CTE stops based on the upstreams list.
cte_mapping = {
cte_name: upstream_node.get_fake_ephemeral_table_name()
for upstream_node in [
all_nodes_map[upstream_node_name]
for upstream_node_name in node.upstream_nodes
if upstream_node_name in all_nodes_map
]
if upstream_node.is_ephemeral_model()
for cte_name in _get_dbt_cte_names(
upstream_node.name, schema_resolver.platform
)
if sql_result.debug_info.error:
self.report.sql_statements_table_error += 1
logger.info(
f"Failed to parse compiled code for {node.dbt_name}: {sql_result.debug_info.error}"
)
elif sql_result.debug_info.column_error:
self.report.sql_statements_column_error += 1
logger.info(
f"Failed to generate CLL for {node.dbt_name}: {sql_result.debug_info.column_error}"
)
else:
self.report.sql_statements_parsed += 1
}

sql_result = self._parse_cll(node, cte_mapping, schema_resolver)
else:
self.report.sql_parser_skipped_missing_code.append(node.dbt_name)

@@ -1212,6 +1192,56 @@ def _infer_schemas_and_update_cll( # noqa: C901
if inferred_schema_fields:
node.columns_setdefault(inferred_schema_fields)

def _parse_cll(
self,
node: DBTNode,
cte_mapping: Dict[str, str],
schema_resolver: SchemaResolver,
) -> SqlParsingResult:
assert node.compiled_code is not None

try:
picked_statement = parse_statements_and_pick(
node.compiled_code,
platform=schema_resolver.platform,
)
except Exception as e:
logger.debug(
f"Failed to parse compiled code. {node.dbt_name} will not have column lineage."
)
self.report.sql_parser_parse_failures += 1
self.report.sql_parser_parse_failures_list.append(node.dbt_name)
return SqlParsingResult.make_from_error(e)

try:
preprocessed_sql = detach_ctes(
picked_statement,
platform=schema_resolver.platform,
cte_mapping=cte_mapping,
)
except Exception as e:
self.report.sql_parser_detach_ctes_failures += 1
self.report.sql_parser_detach_ctes_failures_list.append(node.dbt_name)
logger.debug(
f"Failed to detach CTEs from compiled code. {node.dbt_name} will not have column lineage."
)
return SqlParsingResult.make_from_error(e)

sql_result = sqlglot_lineage(preprocessed_sql, schema_resolver=schema_resolver)
if sql_result.debug_info.table_error:
self.report.sql_parser_table_errors += 1
logger.info(
f"Failed to generate any CLL lineage for {node.dbt_name}: {sql_result.debug_info.error}"
)
elif sql_result.debug_info.column_error:
self.report.sql_parser_column_errors += 1
logger.info(
f"Failed to generate CLL for {node.dbt_name}: {sql_result.debug_info.column_error}"
)
else:
self.report.sql_parser_successes += 1
return sql_result
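
Context for the cte_mapping and _parse_cll changes above: dbt compiles ephemeral models into CTEs, so before running column lineage the parser has to "detach" those CTEs and point their references at stable fake table names. A toy, standalone illustration of that idea (this is not DataHub's detach_ctes; sqlglot is used directly and the table names are made up):

import sqlglot
from sqlglot import exp

compiled = """
with __dbt__cte__stg_orders as (select id, amount from raw.orders)
select id, amount from __dbt__cte__stg_orders
"""

tree = sqlglot.parse_one(compiled)

# Rewrite references to the ephemeral model's CTE so lineage points at a
# resolvable (fake) table instead of dead-ending at the CTE definition.
fake_name = "my_db.my_schema.__datahub__dbt__ephemeral__stg_orders"
for table in list(tree.find_all(exp.Table)):
    if table.name == "__dbt__cte__stg_orders":
        table.replace(exp.to_table(fake_name))

# Drop the now-unused CTE definition.
with_clause = tree.args.get("with")
if with_clause:
    with_clause.pop()

print(tree.sql())
# SELECT id, amount FROM my_db.my_schema.__datahub__dbt__ephemeral__stg_orders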

def create_dbt_platform_mces(
self,
dbt_nodes: List[DBTNode],
@@ -92,6 +92,7 @@
TagPropertiesClass,
TagSnapshotClass,
)
from datahub.metadata.urns import TagUrn
from datahub.utilities.lossy_collections import LossyList, LossySet
from datahub.utilities.url_util import remove_port_from_url

@@ -669,6 +670,7 @@ class LookerExplore:
joins: Optional[List[str]] = None
fields: Optional[List[ViewField]] = None # the fields exposed in this explore
source_file: Optional[str] = None
tags: List[str] = dataclasses_field(default_factory=list)

@validator("name")
def remove_quotes(cls, v):
@@ -770,6 +772,7 @@ def from_dict(
# This method is getting called from lookml_source's get_internal_workunits method
# & upstream_views_file_path is not in use in that code flow
upstream_views_file_path={},
tags=cast(List, dict.get("tags")) if dict.get("tags") is not None else [],
)

@classmethod # noqa: C901
@@ -786,7 +789,6 @@ def from_api( # noqa: C901
try:
explore = client.lookml_model_explore(model, explore_name)
views: Set[str] = set()

lkml_fields: List[
LookmlModelExploreField
] = explore_field_set_to_lkml_fields(explore)
@@ -956,6 +958,7 @@ def from_api( # noqa: C901
),
upstream_views_file_path=upstream_views_file_path,
source_file=explore.source_file,
tags=list(explore.tags) if explore.tags is not None else [],
)
except SDKError as e:
if "<title>Looker Not Found (404)</title>" in str(e):
@@ -1133,6 +1136,20 @@ def _to_metadata_events( # noqa: C901
mcp,
]

# Add tags
explore_tag_urns: List[TagAssociationClass] = []
for tag in self.tags:
tag_urn = TagUrn(tag)
explore_tag_urns.append(TagAssociationClass(tag_urn.urn()))
proposals.append(
MetadataChangeProposalWrapper(
entityUrn=tag_urn.urn(),
aspect=tag_urn.to_key_aspect(),
)
)
if explore_tag_urns:
dataset_snapshot.aspects.append(GlobalTagsClass(explore_tag_urns))

# If extracting embeds is enabled, produce an MCP for embed URL.
if extract_embed_urls:
embed_mcp = create_embed_mcp(
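
The Looker tag block above does two things: it emits a key-aspect MCP per tag so each tag entity exists in DataHub, and it attaches the tags to the explore via a GlobalTags aspect. A minimal sketch of the same pattern outside the Looker source (tag names are made up; the classes are the ones this diff uses, plus the standard MetadataChangeProposalWrapper import):

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
from datahub.metadata.urns import TagUrn

tags = ["pii", "tier_gold"]

# One key-aspect proposal per tag so the tag entity itself exists in DataHub.
tag_proposals = [
    MetadataChangeProposalWrapper(
        entityUrn=TagUrn(tag).urn(),
        aspect=TagUrn(tag).to_key_aspect(),
    )
    for tag in tags
]

# A GlobalTags aspect associating those tags with the target entity.
global_tags = GlobalTagsClass(
    tags=[TagAssociationClass(tag=TagUrn(tag).urn()) for tag in tags]
)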
@@ -45,6 +45,7 @@
create_dataset_props_patch_builder,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws import s3_util
from datahub.ingestion.source.aws.s3_util import (
make_s3_urn_for_lineage,
strip_s3_prefix,
@@ -512,14 +513,16 @@ def process_table(self, table: Table, schema: Schema) -> Iterable[MetadataWorkUnit]:
if table.view_definition:
self.view_definitions[dataset_urn] = (table.ref, table.view_definition)

# generate sibling and lineage aspects in case of EXTERNAL DELTA TABLE
if (
table_props.customProperties.get("table_type") == "EXTERNAL"
table_props.customProperties.get("table_type")
in {"EXTERNAL", "HIVE_EXTERNAL_TABLE"}
and table_props.customProperties.get("data_source_format") == "DELTA"
and self.config.emit_siblings
):
storage_location = str(table_props.customProperties.get("storage_location"))
if storage_location.startswith("s3://"):
if any(
storage_location.startswith(prefix) for prefix in s3_util.S3_PREFIXES
):
browse_path = strip_s3_prefix(storage_location)
source_dataset_urn = make_dataset_urn_with_platform_instance(
"delta-lake",
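
The Unity Catalog change above widens the sibling/lineage check from a literal "s3://" prefix to any prefix in s3_util.S3_PREFIXES. A standalone sketch of that check; the exact prefix values listed here are an assumption about what S3_PREFIXES contains:

# Assumed contents of s3_util.S3_PREFIXES; the real constant lives in
# datahub.ingestion.source.aws.s3_util.
S3_PREFIXES = ("s3://", "s3n://", "s3a://")


def is_s3_location(storage_location: str) -> bool:
    # Matches any supported S3 URI scheme, not just "s3://" as before.
    return any(storage_location.startswith(prefix) for prefix in S3_PREFIXES)


print(is_s3_location("s3a://my-bucket/delta/table"))  # True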
metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py (20 changes: 14 additions & 6 deletions)
@@ -365,7 +365,7 @@ def _column_level_lineage( # noqa: C901
col_normalized = col

table_schema_normalized_mapping[table][col_normalized] = col
normalized_table_schema[col_normalized] = col_type
normalized_table_schema[col_normalized] = col_type or "UNKNOWN"

sqlglot_db_schema.add_table(
table.as_sqlglot_table(),
@@ -923,12 +923,20 @@ def _sqlglot_lineage_inner(
out_urns = sorted({table_name_urn_mapping[table] for table in modified})
column_lineage_urns = None
if column_lineage:
column_lineage_urns = [
_translate_internal_column_lineage(
table_name_urn_mapping, internal_col_lineage, dialect=dialect
try:
column_lineage_urns = [
_translate_internal_column_lineage(
table_name_urn_mapping, internal_col_lineage, dialect=dialect
)
for internal_col_lineage in column_lineage
]
except KeyError as e:
# When this happens, it's usually because of things like PIVOT where we can't
# really go up the scope chain.
logger.debug(
f"Failed to translate column lineage to urns: {e}", exc_info=True
)
for internal_col_lineage in column_lineage
]
debug_info.column_error = e

query_type, query_type_props = get_query_type_of_sql(
original_statement, dialect=dialect
metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py (2 changes: 2 additions & 0 deletions)
@@ -81,6 +81,8 @@ def parse_statement(


def parse_statements_and_pick(sql: str, platform: DialectOrStr) -> sqlglot.Expression:
logger.debug("Parsing SQL query: %s", sql)

dialect = get_dialect(platform)
statements = [
expression for expression in sqlglot.parse(sql, dialect=dialect) if expression
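
parse_statements_and_pick now logs the full SQL query at debug level. If you want to see those messages during an ingestion run without turning on debug logging everywhere, something like the following works with the standard library logger (the logger name is an assumption inferred from the module path shown above):

import logging

logging.basicConfig(level=logging.INFO)
# Only the sqlglot_utils module drops to DEBUG, so full SQL statements are
# printed without flooding the rest of the run with debug output.
logging.getLogger("datahub.sql_parsing.sqlglot_utils").setLevel(logging.DEBUG)

Keep in mind the logged messages contain the full compiled SQL, so treat debug logs accordingly.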