
Commit

feat (ingest/delta-lake): Support ABS file location for delta lake tables
acrylJonny committed Nov 3, 2024
1 parent 00a2751 commit d89b107
Showing 3 changed files with 86 additions and 5 deletions.
@@ -15,6 +15,9 @@
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.aws.s3_util import is_s3_uri

from datahub.ingestion.source.azure.azure_common import AzureConnectionConfig
from datahub.ingestion.source.azure.abs_utils import is_abs_uri

# hide annoying debug errors from py4j
logging.getLogger("py4j").setLevel(logging.ERROR)
logger: logging.Logger = logging.getLogger(__name__)
@@ -35,10 +38,19 @@ class S3(ConfigModel):
description="# Whether or not to create tags in datahub from the s3 object",
)

class Azure(ConfigModel):
"""Azure configuration for Delta Lake source"""
azure_config: Optional[AzureConnectionConfig] = Field(
default=None, description="Azure configuration"
)
use_abs_blob_tags: Optional[bool] = Field(
False,
description="Whether or not to create tags in datahub from Azure blob metadata",
)

class DeltaLakeSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin):
base_path: str = Field(
description="Path to table (s3 or local file system). If path is not a delta table path "
description="Path to table (s3, abfss, or local file system). If path is not a delta table path "
"then all subfolders will be scanned to detect and ingest delta tables."
)
relative_path: Optional[str] = Field(
@@ -73,11 +85,16 @@ class DeltaLakeSourceConfig(PlatformInstanceConfigMixin, EnvConfigMixin):
)

s3: Optional[S3] = Field()
azure: Optional[Azure] = Field()

@cached_property
def is_s3(self):
return is_s3_uri(self.base_path or "")

@cached_property
def is_azure(self):
return is_abs_uri(self.base_path or "")

@cached_property
def complete_path(self):
complete_path = self.base_path
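With the Azure block added to DeltaLakeSourceConfig above, an ABS-hosted table can be configured roughly as follows. This is a minimal sketch: the account, container, and credential values are placeholders, and the AzureConnectionConfig field names are assumptions inferred from the credentials consumed in get_storage_options() further down in this diff.

from datahub.ingestion.source.azure.azure_common import AzureConnectionConfig
from datahub.ingestion.source.delta_lake.config import Azure, DeltaLakeSourceConfig

config = DeltaLakeSourceConfig(
    base_path="abfss://mycontainer@myaccount.dfs.core.windows.net/delta/my_table",
    s3=None,  # no S3 settings are needed for an ABS path
    azure=Azure(
        azure_config=AzureConnectionConfig(
            account_name="myaccount",
            container_name="mycontainer",
            client_id="<client-id>",
            client_secret="<client-secret>",
            tenant_id="<tenant-id>",
        ),
        use_abs_blob_tags=True,
    ),
)

assert config.is_azure  # abfss:// base paths are detected via is_abs_uri()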
@@ -19,7 +19,7 @@
def read_delta_table(
path: str, opts: Dict[str, str], delta_lake_config: DeltaLakeSourceConfig
) -> Optional[DeltaTable]:
if not delta_lake_config.is_s3 and not pathlib.Path(path).exists():
if not (delta_lake_config.is_s3 or delta_lake_config.is_azure) and not pathlib.Path(path).exists():
# The DeltaTable() constructor will create the path if it doesn't exist.
# Hence we need an extra, manual check here.
return None
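read_delta_table() above hands any non-local path straight to the deltalake library, so for ABS the heavy lifting is done by delta-rs. A minimal sketch of opening an ABFSS table directly, with placeholder account details and storage-option keys mirroring those assembled by get_storage_options() below:

from deltalake import DeltaTable

storage_options = {
    "AZURE_STORAGE_ACCOUNT_NAME": "myaccount",
    "AZURE_STORAGE_ACCOUNT_KEY": "<account-key>",
}
delta_table = DeltaTable(
    "abfss://mycontainer@myaccount.dfs.core.windows.net/delta/my_table",
    storage_options=storage_options,
)
print(delta_table.version(), delta_table.metadata().id)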
@@ -29,6 +29,12 @@
get_key_prefix,
strip_s3_prefix,
)
from datahub.ingestion.source.azure.abs_folder_utils import get_abs_tags
from datahub.ingestion.source.azure.abs_utils import (
get_container_name,
get_abs_prefix,
strip_abs_prefix,
)
from datahub.ingestion.source.data_lake_common.data_lake_utils import ContainerWUCreator
from datahub.ingestion.source.delta_lake.config import DeltaLakeSourceConfig
from datahub.ingestion.source.delta_lake.delta_lake_utils import (
@@ -110,6 +116,13 @@ def __init__(self, config: DeltaLakeSourceConfig, ctx: PipelineContext):
):
raise ValueError("AWS Config must be provided for S3 base path.")
self.s3_client = self.source_config.s3.aws_config.get_s3_client()
elif self.source_config.is_azure:
if (
self.source_config.azure is None
or self.source_config.azure.azure_config is None
):
raise ValueError("Azure Config must be provided for ABFSS base path")
self.azure_client = self.source_config.azure.get_azure_client()

# self.profiling_times_taken = []
config_report = {
@@ -203,9 +216,12 @@ def ingest_table(

logger.debug(f"Ingesting table {table_name} from location {path}")
if self.source_config.relative_path is None:
browse_path: str = (
strip_s3_prefix(path) if self.source_config.is_s3 else path.strip("/")
)
if self.source_config.is_s3:
browse_path = strip_s3_prefix(path)
elif self.source_config.is_azure:
browse_path = strip_abs_prefix(path)
else:
browse_path = path.strip("/")
else:
browse_path = path.split(self.source_config.base_path)[1].strip("/")

@@ -271,6 +287,26 @@ def ingest_table(
)
if s3_tags is not None:
dataset_snapshot.aspects.append(s3_tags)
if (
self.source_config.is_azure
and self.source_config.azure
and (
self.source_config.azure.use_azure_container_tags
or self.source_config.azure.use_azure_blob_tags
)
):
container_name = get_container_name(path)
abs_prefix = get_abs_prefix(path)
abs_tags = get_abs_tags(
container_name,
abs_prefix,
dataset_urn,
self.source_config.azure.azure_config,
self.ctx,
self.source_config.azure.use_abs_blob_tags,
)
if abs_tags is not None:
dataset_snapshot.aspects.append(abs_tags)
mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
yield MetadataWorkUnit(id=str(delta_table.metadata().id), mce=mce)
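get_abs_tags() is expected to return a tag aspect that gets appended to the dataset snapshot, analogous to s3_tags above. A sketch of the kind of GlobalTags aspect this produces, built by hand with hypothetical blob tag names (the exact tag format used by get_abs_tags() is not shown in this diff):

from datahub.emitter.mce_builder import make_tag_urn
from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass

# Hypothetical tags read from Azure blob metadata.
blob_tags = {"team": "analytics", "pii": "false"}

abs_tags = GlobalTagsClass(
    tags=[
        TagAssociationClass(tag=make_tag_urn(f"{key}:{value}"))
        for key, value in blob_tags.items()
    ]
)
# The source appends this aspect to dataset_snapshot.aspects, just as it
# does for s3_tags in the branch above.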

@@ -301,6 +337,19 @@ def get_storage_options(self) -> Dict[str, str]:
if aws_config.aws_endpoint_url:
opts["AWS_ENDPOINT_URL"] = aws_config.aws_endpoint_url
return opts
elif self.source_config.is_azure:
azure_config = self.source_config.azure.azure_config
creds = azure_config.get_credentials()
opts = {
"AZURE_STORAGE_ACCOUNT_NAME": azure_config.account_name,
"AZURE_STORAGE_CONTAINER_NAME": azure_config.container_name,
"AZURE_TENANT_ID": creds.get("tenant_id") or "",
"AZURE_CLIENT_ID": creds.get("client_id") or "",
"AZURE_CLIENT_SECRET": creds.get("client_secret") or "",
"AZURE_STORAGE_SAS_TOKEN": creds.get("sas_token") or "",
"AZURE_STORAGE_ACCOUNT_KEY": creds.get("account_key") or ""
}
return opts
else:
return {}

@@ -317,6 +366,8 @@ def process_folder(self, path: str) -> Iterable[MetadataWorkUnit]:
def get_folders(self, path: str) -> Iterable[str]:
if self.source_config.is_s3:
return self.s3_get_folders(path)
elif self.source_config.is_azure:
return self.azure_get_folders(path)
else:
return self.local_get_folders(path)

@@ -328,6 +379,19 @@ def s3_get_folders(self, path: str) -> Iterable[str]:
for o in page.get("CommonPrefixes", []):
yield f"{parse_result.scheme}://{parse_result.netloc}/{o.get('Prefix')}"

def azure_get_folders(self, path: str) -> Iterable[str]:
"""List folders from Azure Storage."""
parsed = urlparse(path)
prefix = parsed.path.lstrip('/')
container_client = self.azure_client.get_container_client(parsed.netloc.split('@')[0])

try:
for item in container_client.walk_blobs(name_starts_with=prefix):
if isinstance(item, dict) and item.get("is_directory", False):
yield f"abfss://{parsed.netloc}/{item['name']}"
except Exception as e:
self.report.report_failure("azure-folders", f"Failed to list ABFSS folders: {e}")
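azure_get_folders() relies on walk_blobs() to enumerate "directories". With the azure-storage-blob SDK and its default "/" delimiter, virtual directories are surfaced as BlobPrefix items while individual blobs arrive as BlobProperties. A folder-listing sketch along those lines, with placeholder account, container, and prefix values:

from azure.storage.blob import BlobPrefix, BlobServiceClient

service = BlobServiceClient(
    account_url="https://myaccount.blob.core.windows.net",
    credential="<account-key-or-sas>",
)
container = service.get_container_client("mycontainer")

# BlobPrefix entries represent virtual directories; their names keep the
# trailing "/" delimiter.
for item in container.walk_blobs(name_starts_with="delta/"):
    if isinstance(item, BlobPrefix):
        print(f"abfss://mycontainer@myaccount.dfs.core.windows.net/{item.name}")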

def local_get_folders(self, path: str) -> Iterable[str]:
if not os.path.isdir(path):
raise FileNotFoundError(

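Taken together, the changes let a standard ingestion recipe point base_path at an ABFSS location. A sketch using DataHub's Python pipeline API with placeholder values; the nested azure block mirrors the config fields introduced in this commit, and the AzureConnectionConfig field names (account_name, container_name, account_key) are assumptions based on the credentials read in get_storage_options():

from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        "source": {
            "type": "delta-lake",
            "config": {
                "base_path": "abfss://mycontainer@myaccount.dfs.core.windows.net/delta",
                "azure": {
                    "azure_config": {
                        "account_name": "myaccount",
                        "container_name": "mycontainer",
                        "account_key": "<account-key>",
                    },
                    "use_abs_blob_tags": True,
                },
            },
        },
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()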