FIX: Stream DB Results to Parquet Files #183
Changes from 1 commit

**File 1 of 3**
```diff
@@ -9,6 +9,8 @@
 import pandas
 import sqlalchemy as sa
 from sqlalchemy.orm import sessionmaker
+import pyarrow
+import pyarrow.parquet as pq

 from lamp_py.aws.s3 import get_datetime_from_partition_path
 from lamp_py.runtime_utils.process_logger import ProcessLogger
```
```diff
@@ -28,6 +30,13 @@ def running_in_docker() -> bool:
     )


+def running_in_aws() -> bool:
+    """
+    return True if running on aws, else False
+    """
+    return bool(os.getenv("AWS_DEFAULT_REGION"))
+
+
 def get_db_password() -> str:
     """
     function to provide rds password
```
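The new `running_in_aws()` helper keys entirely off the `AWS_DEFAULT_REGION` environment variable. A minimal pytest-style sketch of its behavior; the test name and the use of the `monkeypatch` fixture are illustrative and not part of this PR:

```python
def test_running_in_aws(monkeypatch) -> None:
    # assumes the AWS runtime exports AWS_DEFAULT_REGION
    monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
    assert running_in_aws() is True

    # without the variable, the helper reports a non-aws environment
    monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
    assert running_in_aws() is False
```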
```diff
@@ -86,12 +95,14 @@ def get_local_engine(
     db_user = os.environ.get("DB_USER")
     db_ssl_options = ""

-    # when using docker, the db host env var will be "local_rds" but
-    # accessed via the "0.0.0.0" ip address (mac specific)
+    # on mac, when running in docker locally db is accessed by "0.0.0.0" ip
     if db_host == "local_rds" and "macos" in platform.platform().lower():
         db_host = "0.0.0.0"
-    # if not running_in_docker():
-    #     db_host = "127.0.0.1"
+
+    # when running application locally in CLI for configuration
+    # and debugging, db is accessed by localhost ip
+    if not running_in_docker() and not running_in_aws():
+        db_host = "127.0.0.1"

     assert db_host is not None
     assert db_name is not None
```
```diff
@@ -266,6 +277,44 @@ def select_as_list(
         with self.session.begin() as cursor:
             return [row._asdict() for row in cursor.execute(select_query)]

+    def write_to_parquet(
+        self,
+        select_query: sa.sql.selectable.Select,
+        write_path: str,
+        schema: pyarrow.schema,
+        batch_size: int = 1024 * 1024,
+    ) -> str:
+        """
+        stream db query results to parquet file in batches
+
+        this function is meant to limit memory usage when creating very large
+        parquet files from db SELECT
+
+        default batch_size of 1024*1024 is based on "row_group_size" parameter
+        of ParquetWriter.write_batch(): row group size will be the minimum of
+        the RecordBatch size and 1024 * 1024. If set larger than 64Mi then
+        64Mi will be used instead.
+        https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html#pyarrow.parquet.ParquetWriter.write_batch
+
+        :param select_query: query to execute
+        :param write_path: local file path for resulting parquet file
+        :param schema: schema of parquet file from select query
+        :param batch_size: number of records to stream from db per batch
+
+        :return local path to created parquet file
+        """
+        with self.session.begin() as cursor:
+            result = cursor.execute(select_query).yield_per(batch_size)
+            with pq.ParquetWriter(write_path, schema=schema) as pq_writer:
+                for part in result.partitions():
+                    pq_writer.write_batch(
+                        pyarrow.RecordBatch.from_pylist(
+                            [row._asdict() for row in part], schema=schema
+                        )
+                    )
+
+        return write_path
+
+
     def truncate_table(
         self,
         table_to_truncate: Any,
```

**Review comment** (on `return write_path`): i don't think we need to return the write path since it was provided as part of the input.

**Author reply:** Incorporated
```diff
@@ -291,9 +340,7 @@ def truncate_table(
         self.execute(sa.text(f"{truncate_query};"))

         # Execute VACUUM to avoid non-deterministic behavior during testing
-        with self.session.begin() as cursor:
-            cursor.execute(sa.text("END TRANSACTION;"))
-            cursor.execute(sa.text(f"VACUUM (ANALYZE) {truncat_as};"))
+        self.vacuum_analyze(table_to_truncate)

     def vacuum_analyze(self, table: Any) -> None:
         """RUN VACUUM (ANALYZE) on table"""
```
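The body of `vacuum_analyze` is only shown truncated above. In PostgreSQL, `VACUUM` cannot run inside a transaction block, which is why the removed code issued `END TRANSACTION;` first. A plausible sketch of the helper using SQLAlchemy's AUTOCOMMIT isolation instead; the `self.engine` attribute and the name-resolution logic are assumptions, not taken from the diff:

```python
def vacuum_analyze(self, table: Any) -> None:
    """RUN VACUUM (ANALYZE) on table"""
    # resolve an ORM class to its table name, else fall back to str()
    table_name = getattr(table, "__tablename__", str(table))

    # VACUUM must run outside a transaction, so use an AUTOCOMMIT
    # connection rather than session.begin()
    with self.engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    ) as conn:
        conn.execute(sa.text(f"VACUUM (ANALYZE) {table_name};"))
```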
**File 2 of 3**
```diff
@@ -6,6 +6,7 @@

 def validate_environment(
     required_variables: List[str],
+    private_variables: Optional[List[str]] = None,
     optional_variables: Optional[List[str]] = None,
     validate_db: bool = False,
 ) -> None:
```
```diff
@@ -16,6 +17,9 @@ def validate_environment(
     process_logger = ProcessLogger("validate_env")
     process_logger.log_start()

+    if private_variables is None:
+        private_variables = []
+
     # every pipeline needs a service name for logging
     required_variables.append("SERVICE_NAME")
```

**Review comment:** Had to re-structure this function to allow for a …
```diff
@@ -27,29 +31,30 @@
         "DB_PORT",
         "DB_USER",
     ]
-    # if db password is missing, db region is required to generate a
-    # token to use as the password to the cloud database
-    if os.environ.get("DB_PASSWORD", None) is None:
-        required_variables.append("DB_REGION")

     # check for missing variables. add found variables to our logs.
     missing_required = []
     for key in required_variables:
         value = os.environ.get(key, None)
         if value is None:
             missing_required.append(key)
+        # do not log private variables
+        if key in private_variables:
+            value = "**********"
         process_logger.add_metadata(**{key: value})

+    # if db password is missing, db region is required to generate a token to
+    # use as the password to the cloud database
+    if validate_db:
+        if os.environ.get("DB_PASSWORD", None) is None:
+            value = os.environ.get("DB_REGION", None)
+            if value is None:
+                missing_required.append("DB_REGION")
+            process_logger.add_metadata(DB_REGION=value)
+
     # for optional variables, access ones that exist and add them to logs.
     if optional_variables:
         for key in optional_variables:
             value = os.environ.get(key, None)
             if value is not None:
+                # do not log private variables
+                if key in private_variables:
+                    value = "**********"
                 process_logger.add_metadata(**{key: value})

     # if required variables are missing, log a failure and throw.
```
**File 3 of 3**
```diff
@@ -56,8 +56,6 @@ def __init__(
             "s3://", ""
         )

-        self.db_manager = DatabaseManager()
-
     @property
     @abstractmethod
     def parquet_schema(self) -> pyarrow.schema:
```

**Review comment** (on lines 59-60): dropped
```diff
@@ -66,13 +64,13 @@ def parquet_schema(self) -> pyarrow.schema:
         """

     @abstractmethod
-    def create_parquet(self) -> None:
+    def create_parquet(self, db_manager: DatabaseManager) -> None:
         """
         Business logic to create new Job parquet file
         """

     @abstractmethod
-    def update_parquet(self) -> bool:
+    def update_parquet(self, db_manager: DatabaseManager) -> bool:
         """
         Business logic to update existing Job parquet file
```
```diff
@@ -261,7 +259,7 @@ def run_hyper(self) -> None:
             if retry_count == max_retries - 1:
                 process_log.log_failure(exception=exception)

-    def run_parquet(self) -> None:
+    def run_parquet(self, db_manager: DatabaseManager) -> None:
        """
        Remote parquet Create / Update runner
```
```diff
@@ -293,10 +291,10 @@ def run_parquet(self) -> None:
                 # remote schema does not match expected local schema
                 run_action = "create"
                 upload_parquet = True
-                self.create_parquet()
+                self.create_parquet(db_manager)
             else:
                 run_action = "update"
-                upload_parquet = self.update_parquet()
+                upload_parquet = self.update_parquet(db_manager)

             parquet_file_size_mb = os.path.getsize(self.local_parquet_path) / (
                 1024 * 1024
```
**Review comment:** pulled this over from `dmap_import`
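Putting the pieces together, a hedged sketch of what a concrete Job might look like after this PR. The base-class name `HyperJob`, the `events` table (from the earlier sketch), and the schema are assumptions for illustration; only the `db_manager` injection, `self.local_parquet_path`, and the call into `DatabaseManager.write_to_parquet` come from the diff:

```python
class VehicleEventsJob(HyperJob):  # base-class name assumed
    @property
    def parquet_schema(self) -> pyarrow.schema:
        return pyarrow.schema([("trip_id", pyarrow.string())])

    def create_parquet(self, db_manager: DatabaseManager) -> None:
        # stream the full SELECT to the job's local parquet path in batches
        db_manager.write_to_parquet(
            select_query=sa.select(events.c.trip_id),
            write_path=self.local_parquet_path,
            schema=self.parquet_schema,
        )

    def update_parquet(self, db_manager: DatabaseManager) -> bool:
        # simplest possible update: rebuild from scratch and re-upload
        self.create_parquet(db_manager)
        return True
```

Passing the `DatabaseManager` into `create_parquet` / `update_parquet` instead of constructing one in `__init__` means a single manager (and its connection pool) can be shared across jobs by the `run_parquet` caller.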