FIX: Stream DB Results to Parquet Files #183

Merged 2 commits on Nov 13, 2023 (changes shown below are from 1 commit)
python_src/src/lamp_py/performance_manager/pipeline.py (4 changes: 2 additions & 2 deletions)
@@ -15,7 +15,7 @@
from lamp_py.runtime_utils.env_validation import validate_environment
from lamp_py.runtime_utils.process_logger import ProcessLogger

# from lamp_py.tableau.pipeline import start_parquet_updates
from lamp_py.tableau.pipeline import start_parquet_updates

from .flat_file import write_flat_files
from .l0_gtfs_rt_events import process_gtfs_rt_files
@@ -69,7 +69,7 @@ def iteration() -> None:
process_static_tables(db_manager)
process_gtfs_rt_files(db_manager)
write_flat_files(db_manager)
# start_parquet_updates()
start_parquet_updates(db_manager)

process_logger.log_complete()
except Exception as exception:
python_src/src/lamp_py/postgres/postgres_utils.py (61 changes: 54 additions & 7 deletions)
@@ -9,6 +9,8 @@
import pandas
import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker
import pyarrow
import pyarrow.parquet as pq

from lamp_py.aws.s3 import get_datetime_from_partition_path
from lamp_py.runtime_utils.process_logger import ProcessLogger
@@ -28,6 +30,13 @@ def running_in_docker() -> bool:
)


def running_in_aws() -> bool:
"""
return True if running on aws, else False
"""
return bool(os.getenv("AWS_DEFAULT_REGION"))


def get_db_password() -> str:
"""
function to provide rds password
@@ -86,12 +95,14 @@ def get_local_engine(
db_user = os.environ.get("DB_USER")
db_ssl_options = ""

# when using docker, the db host env var will be "local_rds" but
# accessed via the "0.0.0.0" ip address (mac specific)
# on mac, when running in docker locally db is accessed by "0.0.0.0" ip
if db_host == "local_rds" and "macos" in platform.platform().lower():
db_host = "0.0.0.0"
# if not running_in_docker():
# db_host = "127.0.0.1"

# when running application locally in CLI for configuration
# and debugging, db is accessed by localhost ip
if not running_in_docker() and not running_in_aws():
db_host = "127.0.0.1"
Review comment (Collaborator Author, on lines +104 to +105): pulled this over from dmap_import


assert db_host is not None
assert db_name is not None
@@ -266,6 +277,44 @@ def select_as_list(
with self.session.begin() as cursor:
return [row._asdict() for row in cursor.execute(select_query)]

def write_to_parquet(
self,
select_query: sa.sql.selectable.Select,
write_path: str,
schema: pyarrow.schema,
batch_size: int = 1024 * 1024,
) -> str:
"""
stream db query results to parquet file in batches

this function is meant to limit memory usage when creating very large
parquet files from db SELECT

default batch_size of 1024*1024 is based on "row_group_size" parameter
of ParquetWriter.write_batch(): row group size will be the minimum of
the RecordBatch size and 1024 * 1024. If set larger
than 64Mi then 64Mi will be used instead.
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html#pyarrow.parquet.ParquetWriter.write_batch

:param select_query: query to execute
:param write_path: local file path for resulting parquet file
:param schema: schema of parquet file from select query
:param batch_size: number of records to stream from db per batch

:return local path to created parquet file
"""
with self.session.begin() as cursor:
result = cursor.execute(select_query).yield_per(batch_size)
with pq.ParquetWriter(write_path, schema=schema) as pq_writer:
for part in result.partitions():
pq_writer.write_batch(
pyarrow.RecordBatch.from_pylist(
[row._asdict() for row in part], schema=schema
)
)

return write_path
Review comment (Contributor): i don't think we need to return the write path since it was provided as part of the input.

Review comment (Collaborator Author): Incorporated
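
For context, a minimal sketch of how a caller might use the new `write_to_parquet` method. The table, columns, and schema below are made up for illustration and are not taken from this PR:

```python
import pyarrow
import sqlalchemy as sa

from lamp_py.postgres.postgres_utils import DatabaseManager

# hypothetical schema matching a hypothetical SELECT
schema = pyarrow.schema(
    [
        ("vehicle_id", pyarrow.string()),
        ("move_timestamp", pyarrow.int64()),
    ]
)

db_manager = DatabaseManager()

# results are streamed to the parquet file in batches, so the full
# result set is never held in memory at once
db_manager.write_to_parquet(
    select_query=sa.text("SELECT vehicle_id, move_timestamp FROM vehicle_events"),
    write_path="/tmp/vehicle_events.parquet",
    schema=schema,
)
```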


def truncate_table(
self,
table_to_truncate: Any,
@@ -291,9 +340,7 @@ def truncate_table(
self.execute(sa.text(f"{truncate_query};"))

# Execute VACUUM to avoid non-deterministic behavior during testing
with self.session.begin() as cursor:
cursor.execute(sa.text("END TRANSACTION;"))
cursor.execute(sa.text(f"VACUUM (ANALYZE) {truncat_as};"))
self.vacuum_analyze(table_to_truncate)

def vacuum_analyze(self, table: Any) -> None:
"""RUN VACUUM (ANALYZE) on table"""
python_src/src/lamp_py/runtime_utils/env_validation.py (23 changes: 14 additions & 9 deletions)
@@ -6,6 +6,7 @@

def validate_environment(
required_variables: List[str],
private_variables: Optional[List[str]] = None,
optional_variables: Optional[List[str]] = None,
validate_db: bool = False,
) -> None:
@@ -16,6 +17,9 @@ def validate_environment(
process_logger = ProcessLogger("validate_env")
process_logger.log_start()

if private_variables is None:
private_variables = []

Review comment (Collaborator Author): Had to re-structure this function to allow for a private_variables parameter and avoid a pylint too-many-branches flag.

# every pipeline needs a service name for logging
required_variables.append("SERVICE_NAME")

@@ -27,29 +31,30 @@
"DB_PORT",
"DB_USER",
]
# if db password is missing, db region is required to generate a
# token to use as the password to the cloud database
if os.environ.get("DB_PASSWORD", None) is None:
required_variables.append("DB_REGION")

# check for missing variables. add found variables to our logs.
missing_required = []
for key in required_variables:
value = os.environ.get(key, None)
if value is None:
missing_required.append(key)
# do not log private variables
if key in private_variables:
value = "**********"
process_logger.add_metadata(**{key: value})

# if db password is missing, db region is required to generate a token to
# use as the password to the cloud database
if validate_db:
if os.environ.get("DB_PASSWORD", None) is None:
value = os.environ.get("DB_REGION", None)
if value is None:
missing_required.append("DB_REGION")
process_logger.add_metadata(DB_REGION=value)

# for optional variables, access ones that exist and add them to logs.
if optional_variables:
for key in optional_variables:
value = os.environ.get(key, None)
if value is not None:
# do not log private variables
if key in private_variables:
value = "**********"
process_logger.add_metadata(**{key: value})

# if required variables are missing, log a failure and throw.
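A usage sketch of the new parameter; the environment variable names here are illustrative, not taken from this PR:

```python
from lamp_py.runtime_utils.env_validation import validate_environment

# DB_PASSWORD is still required to be set, but its value is
# written to the process logs as "**********"
validate_environment(
    required_variables=["ARCHIVE_BUCKET", "DB_PASSWORD"],
    private_variables=["DB_PASSWORD"],
    optional_variables=["ALERTS_S3_PREFIX"],
    validate_db=True,
)
```
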
python_src/src/lamp_py/tableau/hyper.py (12 changes: 5 additions & 7 deletions)
@@ -56,8 +56,6 @@ def __init__(
"s3://", ""
)

self.db_manager = DatabaseManager()

Review comment (Collaborator Author, on lines -59 to -60): dropped self.db_manager from being created for all hyper jobs. This would have led to issues with the Hyper file writing ECS task, which won't be able to connect to our RDS.

db_manager is now passed directly into the create_parquet and update_parquet methods, as they are the only portions of the class that require db access.

@property
@abstractmethod
def parquet_schema(self) -> pyarrow.schema:
@@ -66,13 +64,13 @@ def parquet_schema(self) -> pyarrow.schema:
"""

@abstractmethod
def create_parquet(self) -> None:
def create_parquet(self, db_manager: DatabaseManager) -> None:
"""
Business logic to create new Job parquet file
"""

@abstractmethod
def update_parquet(self) -> bool:
def update_parquet(self, db_manager: DatabaseManager) -> bool:
"""
Business logic to update existing Job parquet file
@@ -261,7 +259,7 @@ def run_hyper(self) -> None:
if retry_count == max_retries - 1:
process_log.log_failure(exception=exception)

def run_parquet(self) -> None:
def run_parquet(self, db_manager: DatabaseManager) -> None:
"""
Remote parquet Create / Update runner
@@ -293,10 +291,10 @@ def run_parquet(self) -> None:
# remote schema does not match expected local schema
run_action = "create"
upload_parquet = True
self.create_parquet()
self.create_parquet(db_manager)
else:
run_action = "update"
upload_parquet = self.update_parquet()
upload_parquet = self.update_parquet(db_manager)

parquet_file_size_mb = os.path.getsize(self.local_parquet_path) / (
1024 * 1024
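With db_manager injected, start_parquet_updates in the performance_manager pipeline presumably threads the single DatabaseManager through each job's run_parquet. A hypothetical sketch; the job list and function body are assumptions, not code shown in this PR:

```python
from typing import List

from lamp_py.postgres.postgres_utils import DatabaseManager
from lamp_py.tableau.hyper import HyperJob


def start_parquet_updates(db_manager: DatabaseManager) -> None:
    """run parquet create/update for every tableau HyperJob"""
    # hypothetical: the real job list lives in lamp_py.tableau.pipeline
    parquet_update_jobs: List[HyperJob] = []

    for job in parquet_update_jobs:
        # each job compares remote and local schemas/keys, then decides
        # whether to create or update its parquet file before re-uploading
        job.run_parquet(db_manager)
```
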
python_src/src/lamp_py/tableau/jobs/gtfs_rail.py (57 changes: 31 additions & 26 deletions)
@@ -2,10 +2,12 @@

import pyarrow
import pyarrow.parquet as pq
import pyarrow.dataset as pd
import sqlalchemy as sa

from lamp_py.tableau.hyper import HyperJob
from lamp_py.aws.s3 import download_file
from lamp_py.postgres.postgres_utils import DatabaseManager


class HyperGTFS(HyperJob):
@@ -34,21 +36,17 @@ def __init__(
def parquet_schema(self) -> pyarrow.schema:
"""Define GTFS Table Schema"""

def create_parquet(self) -> None:
def create_parquet(self, db_manager: DatabaseManager) -> None:
if os.path.exists(self.local_parquet_path):
os.remove(self.local_parquet_path)

pq.write_table(
pyarrow.Table.from_pylist(
mapping=self.db_manager.select_as_list(
sa.text(self.create_query)
),
schema=self.parquet_schema,
),
self.local_parquet_path,
db_manager.write_to_parquet(
select_query=sa.text(self.create_query),
write_path=self.local_parquet_path,
schema=self.parquet_schema,
)

def update_parquet(self) -> bool:
def update_parquet(self, db_manager: DatabaseManager) -> bool:
download_file(
object_path=self.remote_parquet_path,
file_name=self.local_parquet_path,
@@ -62,9 +60,7 @@ def update_parquet(self) -> bool:
f"SELECT MAX(static_version_key) FROM {self.gtfs_table_name};"
)

max_db_key = self.db_manager.select_as_list(sa.text(max_key_query))[0][
"max"
]
max_db_key = db_manager.select_as_list(sa.text(max_key_query))[0]["max"]

# no update needed
if max_db_key <= max_parquet_key:
@@ -75,21 +71,30 @@
f" WHERE static_version_key > {max_parquet_key} ",
)

pq.write_table(
pyarrow.concat_tables(
[
pq.read_table(self.local_parquet_path),
pyarrow.Table.from_pylist(
mapping=self.db_manager.select_as_list(
sa.text(update_query)
),
schema=self.parquet_schema,
),
]
),
self.local_parquet_path,
db_parquet_path = db_manager.write_to_parquet(
select_query=sa.text(update_query),
write_path="/tmp/db_local.parquet",
schema=self.parquet_schema,
)

old_ds = pd.dataset(self.local_parquet_path)
new_ds = pd.dataset(db_parquet_path)

combine_parquet_path = "/tmp/combine.parquet"
combine_batches = pd.dataset(
[old_ds, new_ds],
schema=self.parquet_schema,
).to_batches(batch_size=1024 * 1024)

with pq.ParquetWriter(
combine_parquet_path, schema=self.parquet_schema
) as writer:
for batch in combine_batches:
writer.write_batch(batch)

os.replace(combine_parquet_path, self.local_parquet_path)
os.remove(db_parquet_path)

return True


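The update path above replaces the old pattern of reading the existing file and the full query result into memory with pyarrow.concat_tables. The same batched merge could be expressed as a small standalone helper, sketched here under the assumption that all inputs share the target schema (this helper does not exist in the repo):

```python
from typing import List

import pyarrow
import pyarrow.dataset as ds
import pyarrow.parquet as pq


def merge_parquet_files(
    paths: List[str],
    schema: pyarrow.Schema,
    out_path: str,
    batch_size: int = 1024 * 1024,
) -> None:
    """combine parquet files while holding only one record batch in memory"""
    batches = ds.dataset(paths, schema=schema).to_batches(batch_size=batch_size)
    with pq.ParquetWriter(out_path, schema=schema) as writer:
        for batch in batches:
            writer.write_batch(batch)
```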