FIX: Stream DB Results to Parquet Files (#183)
The current Parquet -> Tableau pipeline is flawed in how much memory it requires to create parquet files from DB SELECT queries. This change is meant to keep memory usage fixed when creating a parquet file, regardless of how many results a DB query returns.

This fixed memory usage is achieved by using the yield_per method of the SQLAlchemy Result object together with the RecordBatch object from the pyarrow library.

In testing, memory usage when creating a parquet file from the static_stop_times table maxes out at approximately 5-6 GB.

If memory usage needs to be limited further, the write_to_parquet function of DatabaseManager offers a batch_size parameter to limit the number of records flowing into a parquet file per partition.
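
For illustration, here is a minimal sketch of the streaming pattern this change introduces. The connection URL, table, and schema are placeholders only; the real pipeline builds these from DatabaseManager and each job's parquet_schema property.

```python
import pyarrow
import pyarrow.parquet as pq
import sqlalchemy as sa
from sqlalchemy.orm import Session

# placeholder connection, query, and schema -- not the production values
engine = sa.create_engine("postgresql+psycopg2://user:pass@localhost/lamp")
select_query = sa.text("SELECT stop_id, arrival_time FROM static_stop_times")
schema = pyarrow.schema(
    [("stop_id", pyarrow.string()), ("arrival_time", pyarrow.int64())]
)
batch_size = 1024 * 1024

with Session(engine) as session:
    # yield_per streams rows from the cursor in fixed-size chunks instead of
    # fetching the entire result set into memory
    result = session.execute(select_query).yield_per(batch_size)
    with pq.ParquetWriter("/tmp/out.parquet", schema=schema) as writer:
        for partition in result.partitions():
            # each partition becomes one RecordBatch / parquet row group, so
            # peak memory is bounded by batch_size, not by total row count
            writer.write_batch(
                pyarrow.RecordBatch.from_pylist(
                    [row._asdict() for row in partition], schema=schema
                )
            )
```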

Asana Task: https://app.asana.com/0/1205827492903547/1205940053804614
rymarczy authored Nov 13, 2023
1 parent 53f4f37 commit bc28218
Showing 7 changed files with 158 additions and 79 deletions.
4 changes: 2 additions & 2 deletions python_src/src/lamp_py/performance_manager/pipeline.py
@@ -15,7 +15,7 @@
from lamp_py.runtime_utils.env_validation import validate_environment
from lamp_py.runtime_utils.process_logger import ProcessLogger

# from lamp_py.tableau.pipeline import start_parquet_updates
from lamp_py.tableau.pipeline import start_parquet_updates

from .flat_file import write_flat_files
from .l0_gtfs_rt_events import process_gtfs_rt_files
@@ -69,7 +69,7 @@ def iteration() -> None:
process_static_tables(db_manager)
process_gtfs_rt_files(db_manager)
write_flat_files(db_manager)
# start_parquet_updates()
start_parquet_updates(db_manager)

process_logger.log_complete()
except Exception as exception:
57 changes: 50 additions & 7 deletions python_src/src/lamp_py/postgres/postgres_utils.py
@@ -9,6 +9,8 @@
import pandas
import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker
import pyarrow
import pyarrow.parquet as pq

from lamp_py.aws.s3 import get_datetime_from_partition_path
from lamp_py.runtime_utils.process_logger import ProcessLogger
@@ -28,6 +30,13 @@ def running_in_docker() -> bool:
)


def running_in_aws() -> bool:
"""
return True if running on aws, else False
"""
return bool(os.getenv("AWS_DEFAULT_REGION"))


def get_db_password() -> str:
"""
function to provide rds password
@@ -86,12 +95,14 @@ def get_local_engine(
db_user = os.environ.get("DB_USER")
db_ssl_options = ""

# when using docker, the db host env var will be "local_rds" but
# accessed via the "0.0.0.0" ip address (mac specific)
# on mac, when running in docker locally db is accessed by "0.0.0.0" ip
if db_host == "local_rds" and "macos" in platform.platform().lower():
db_host = "0.0.0.0"
# if not running_in_docker():
# db_host = "127.0.0.1"

# when running application locally in CLI for configuration
# and debugging, db is accessed by localhost ip
if not running_in_docker() and not running_in_aws():
db_host = "127.0.0.1"

assert db_host is not None
assert db_name is not None
@@ -266,6 +277,40 @@ def select_as_list(
with self.session.begin() as cursor:
return [row._asdict() for row in cursor.execute(select_query)]

def write_to_parquet(
self,
select_query: sa.sql.selectable.Select,
write_path: str,
schema: pyarrow.schema,
batch_size: int = 1024 * 1024,
) -> None:
"""
stream db query results to parquet file in batches
this function is meant to limit memory usage when creating very large
parquet files from db SELECT
default batch_size of 1024*1024 is based on "row_group_size" parameter
of ParquetWriter.write_batch(): row group size will be the minimum of
the RecordBatch size and 1024 * 1024. If set larger
than 64Mi then 64Mi will be used instead.
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html#pyarrow.parquet.ParquetWriter.write_batch
:param select_query: query to execute
:param write_path: local file path for resulting parquet file
:param schema: schema of parquet file from select query
:param batch_size: number of records to stream from db per batch
"""
with self.session.begin() as cursor:
result = cursor.execute(select_query).yield_per(batch_size)
with pq.ParquetWriter(write_path, schema=schema) as pq_writer:
for part in result.partitions():
pq_writer.write_batch(
pyarrow.RecordBatch.from_pylist(
[row._asdict() for row in part], schema=schema
)
)

def truncate_table(
self,
table_to_truncate: Any,
@@ -291,9 +336,7 @@ def truncate_table(
self.execute(sa.text(f"{truncate_query};"))

# Execute VACUUM to avoid non-deterministic behavior during testing
with self.session.begin() as cursor:
cursor.execute(sa.text("END TRANSACTION;"))
cursor.execute(sa.text(f"VACUUM (ANALYZE) {truncat_as};"))
self.vacuum_analyze(table_to_truncate)

def vacuum_analyze(self, table: Any) -> None:
"""RUN VACUUM (ANALYZE) on table"""
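
As a usage illustration, a call to the new write_to_parquet might look like the sketch below; the query, schema, and output path are hypothetical and only show the shape of the call.

```python
import pyarrow
import sqlalchemy as sa

from lamp_py.postgres.postgres_utils import DatabaseManager

db_manager = DatabaseManager()

# hypothetical schema and query; real jobs supply these from their
# parquet_schema and create_query definitions
schema = pyarrow.schema(
    [("service_date", pyarrow.int64()), ("trip_id", pyarrow.string())]
)

db_manager.write_to_parquet(
    select_query=sa.text("SELECT service_date, trip_id FROM vehicle_trips"),
    write_path="/tmp/vehicle_trips.parquet",
    schema=schema,
    batch_size=256 * 1024,  # lower the default to cap memory usage further
)
```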
23 changes: 14 additions & 9 deletions python_src/src/lamp_py/runtime_utils/env_validation.py
@@ -6,6 +6,7 @@

def validate_environment(
required_variables: List[str],
private_variables: Optional[List[str]] = None,
optional_variables: Optional[List[str]] = None,
validate_db: bool = False,
) -> None:
@@ -16,6 +17,9 @@ def validate_environment(
process_logger = ProcessLogger("validate_env")
process_logger.log_start()

if private_variables is None:
private_variables = []

# every pipeline needs a service name for logging
required_variables.append("SERVICE_NAME")

@@ -27,29 +31,30 @@
"DB_PORT",
"DB_USER",
]
# if db password is missing, db region is required to generate a
# token to use as the password to the cloud database
if os.environ.get("DB_PASSWORD", None) is None:
required_variables.append("DB_REGION")

# check for missing variables. add found variables to our logs.
missing_required = []
for key in required_variables:
value = os.environ.get(key, None)
if value is None:
missing_required.append(key)
# do not log private variables
if key in private_variables:
value = "**********"
process_logger.add_metadata(**{key: value})

# if db password is missing, db region is required to generate a token to
# use as the password to the cloud database
if validate_db:
if os.environ.get("DB_PASSWORD", None) is None:
value = os.environ.get("DB_REGION", None)
if value is None:
missing_required.append("DB_REGION")
process_logger.add_metadata(DB_REGION=value)

# for optional variables, access ones that exist and add them to logs.
if optional_variables:
for key in optional_variables:
value = os.environ.get(key, None)
if value is not None:
# do not log private variables
if key in private_variables:
value = "**********"
process_logger.add_metadata(**{key: value})

# if required variables are missing, log a failure and throw.
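
A hypothetical call showing the new private_variables parameter; the variable names below are examples, not the pipeline's actual configuration.

```python
from lamp_py.runtime_utils.env_validation import validate_environment

# DB_PASSWORD is validated like any other required variable, but its value
# is logged as "**********" because it is listed in private_variables
validate_environment(
    required_variables=["ARCHIVE_BUCKET", "DB_PASSWORD"],
    private_variables=["DB_PASSWORD"],
    validate_db=True,
)
```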
12 changes: 5 additions & 7 deletions python_src/src/lamp_py/tableau/hyper.py
@@ -56,8 +56,6 @@ def __init__(
"s3://", ""
)

self.db_manager = DatabaseManager()

@property
@abstractmethod
def parquet_schema(self) -> pyarrow.schema:
@@ -66,13 +64,13 @@ def parquet_schema(self) -> pyarrow.schema:
"""

@abstractmethod
def create_parquet(self) -> None:
def create_parquet(self, db_manager: DatabaseManager) -> None:
"""
Business logic to create new Job parquet file
"""

@abstractmethod
def update_parquet(self) -> bool:
def update_parquet(self, db_manager: DatabaseManager) -> bool:
"""
Business logic to update existing Job parquet file
@@ -261,7 +259,7 @@ def run_hyper(self) -> None:
if retry_count == max_retries - 1:
process_log.log_failure(exception=exception)

def run_parquet(self) -> None:
def run_parquet(self, db_manager: DatabaseManager) -> None:
"""
Remote parquet Create / Update runner
@@ -293,10 +291,10 @@ def run_parquet(self) -> None:
# remote schema does not match expected local schema
run_action = "create"
upload_parquet = True
self.create_parquet()
self.create_parquet(db_manager)
else:
run_action = "update"
upload_parquet = self.update_parquet()
upload_parquet = self.update_parquet(db_manager)

parquet_file_size_mb = os.path.getsize(self.local_parquet_path) / (
1024 * 1024
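
The hyper.py change threads a single DatabaseManager through the job API instead of each HyperJob constructing its own. A hedged sketch of a driver loop under that assumption follows; the function name and job list are illustrative.

```python
from typing import List

from lamp_py.postgres.postgres_utils import DatabaseManager
from lamp_py.tableau.hyper import HyperJob


def run_parquet_jobs(jobs: List[HyperJob], db_manager: DatabaseManager) -> None:
    # one shared DatabaseManager is handed to every job, mirroring how
    # start_parquet_updates(db_manager) is now invoked from pipeline.py
    for job in jobs:
        job.run_parquet(db_manager)
```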
58 changes: 32 additions & 26 deletions python_src/src/lamp_py/tableau/jobs/gtfs_rail.py
@@ -2,10 +2,12 @@

import pyarrow
import pyarrow.parquet as pq
import pyarrow.dataset as pd
import sqlalchemy as sa

from lamp_py.tableau.hyper import HyperJob
from lamp_py.aws.s3 import download_file
from lamp_py.postgres.postgres_utils import DatabaseManager


class HyperGTFS(HyperJob):
@@ -34,21 +36,17 @@ def __init__(
def parquet_schema(self) -> pyarrow.schema:
"""Define GTFS Table Schema"""

def create_parquet(self) -> None:
def create_parquet(self, db_manager: DatabaseManager) -> None:
if os.path.exists(self.local_parquet_path):
os.remove(self.local_parquet_path)

pq.write_table(
pyarrow.Table.from_pylist(
mapping=self.db_manager.select_as_list(
sa.text(self.create_query)
),
schema=self.parquet_schema,
),
self.local_parquet_path,
db_manager.write_to_parquet(
select_query=sa.text(self.create_query),
write_path=self.local_parquet_path,
schema=self.parquet_schema,
)

def update_parquet(self) -> bool:
def update_parquet(self, db_manager: DatabaseManager) -> bool:
download_file(
object_path=self.remote_parquet_path,
file_name=self.local_parquet_path,
@@ -62,9 +60,7 @@ def update_parquet(self) -> bool:
f"SELECT MAX(static_version_key) FROM {self.gtfs_table_name};"
)

max_db_key = self.db_manager.select_as_list(sa.text(max_key_query))[0][
"max"
]
max_db_key = db_manager.select_as_list(sa.text(max_key_query))[0]["max"]

# no update needed
if max_db_key <= max_parquet_key:
Expand All @@ -75,21 +71,31 @@ def update_parquet(self) -> bool:
f" WHERE static_version_key > {max_parquet_key} ",
)

pq.write_table(
pyarrow.concat_tables(
[
pq.read_table(self.local_parquet_path),
pyarrow.Table.from_pylist(
mapping=self.db_manager.select_as_list(
sa.text(update_query)
),
schema=self.parquet_schema,
),
]
),
self.local_parquet_path,
db_parquet_path = "/tmp/db_local.parquet"
db_manager.write_to_parquet(
select_query=sa.text(update_query),
write_path=db_parquet_path,
schema=self.parquet_schema,
)

old_ds = pd.dataset(self.local_parquet_path)
new_ds = pd.dataset(db_parquet_path)

combine_parquet_path = "/tmp/combine.parquet"
combine_batches = pd.dataset(
[old_ds, new_ds],
schema=self.parquet_schema,
).to_batches(batch_size=1024 * 1024)

with pq.ParquetWriter(
combine_parquet_path, schema=self.parquet_schema
) as writer:
for batch in combine_batches:
writer.write_batch(batch)

os.replace(combine_parquet_path, self.local_parquet_path)
os.remove(db_parquet_path)

return True


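
For reference, the update path above merges the existing parquet file with the newly streamed one without loading either fully into memory. A standalone sketch of that pyarrow.dataset pattern follows; file paths and schema are placeholders.

```python
import os

import pyarrow
import pyarrow.dataset as pd
import pyarrow.parquet as pq

# placeholder paths and schema
schema = pyarrow.schema([("static_version_key", pyarrow.int64())])
existing_path = "/tmp/existing.parquet"
new_path = "/tmp/db_local.parquet"
combined_path = "/tmp/combine.parquet"

# a union dataset over both files is scanned in bounded-size batches, so the
# merge never materializes either table in memory at once
batches = pd.dataset(
    [pd.dataset(existing_path), pd.dataset(new_path)],
    schema=schema,
).to_batches(batch_size=1024 * 1024)

with pq.ParquetWriter(combined_path, schema=schema) as writer:
    for batch in batches:
        writer.write_batch(batch)

# swap the combined file into place and drop the intermediate
os.replace(combined_path, existing_path)
os.remove(new_path)
```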
(2 additional changed files not shown)
