diff --git a/python_src/src/lamp_py/performance_manager/pipeline.py b/python_src/src/lamp_py/performance_manager/pipeline.py
index 4e98e6a1..31073cb0 100755
--- a/python_src/src/lamp_py/performance_manager/pipeline.py
+++ b/python_src/src/lamp_py/performance_manager/pipeline.py
@@ -15,7 +15,7 @@
 from lamp_py.runtime_utils.env_validation import validate_environment
 from lamp_py.runtime_utils.process_logger import ProcessLogger
 
-# from lamp_py.tableau.pipeline import start_parquet_updates
+from lamp_py.tableau.pipeline import start_parquet_updates
 
 from .flat_file import write_flat_files
 from .l0_gtfs_rt_events import process_gtfs_rt_files
@@ -69,7 +69,7 @@ def iteration() -> None:
         process_static_tables(db_manager)
         process_gtfs_rt_files(db_manager)
         write_flat_files(db_manager)
-        # start_parquet_updates()
+        start_parquet_updates(db_manager)
 
         process_logger.log_complete()
     except Exception as exception:
diff --git a/python_src/src/lamp_py/postgres/postgres_utils.py b/python_src/src/lamp_py/postgres/postgres_utils.py
index 472ab2f6..c7b354c5 100644
--- a/python_src/src/lamp_py/postgres/postgres_utils.py
+++ b/python_src/src/lamp_py/postgres/postgres_utils.py
@@ -9,6 +9,8 @@
 import pandas
 import sqlalchemy as sa
 from sqlalchemy.orm import sessionmaker
+import pyarrow
+import pyarrow.parquet as pq
 
 from lamp_py.aws.s3 import get_datetime_from_partition_path
 from lamp_py.runtime_utils.process_logger import ProcessLogger
@@ -28,6 +30,13 @@ def running_in_docker() -> bool:
     )
 
 
+def running_in_aws() -> bool:
+    """
+    return True if running on aws, else False
+    """
+    return bool(os.getenv("AWS_DEFAULT_REGION"))
+
+
 def get_db_password() -> str:
     """
     function to provide rds password
@@ -86,12 +95,14 @@ def get_local_engine(
     db_user = os.environ.get("DB_USER")
     db_ssl_options = ""
 
-    # when using docker, the db host env var will be "local_rds" but
-    # accessed via the "0.0.0.0" ip address (mac specific)
+    # on mac, when running in docker locally db is accessed by "0.0.0.0" ip
     if db_host == "local_rds" and "macos" in platform.platform().lower():
         db_host = "0.0.0.0"
-    # if not running_in_docker():
-    #     db_host = "127.0.0.1"
+
+    # when running application locally in CLI for configuration
+    # and debugging, db is accessed by localhost ip
+    if not running_in_docker() and not running_in_aws():
+        db_host = "127.0.0.1"
 
     assert db_host is not None
     assert db_name is not None
@@ -266,6 +277,40 @@ def select_as_list(
         with self.session.begin() as cursor:
             return [row._asdict() for row in cursor.execute(select_query)]
 
+    def write_to_parquet(
+        self,
+        select_query: sa.sql.selectable.Select,
+        write_path: str,
+        schema: pyarrow.schema,
+        batch_size: int = 1024 * 1024,
+    ) -> None:
+        """
+        stream db query results to parquet file in batches
+
+        this function is meant to limit memory usage when creating very large
+        parquet files from db SELECT
+
+        default batch_size of 1024*1024 is based on "row_group_size" parameter
+        of ParquetWriter.write_batch(): row group size will be the minimum of
+        the RecordBatch size and 1024 * 1024. If set larger
+        than 64Mi then 64Mi will be used instead.
+        https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html#pyarrow.parquet.ParquetWriter.write_batch
+
+        :param select_query: query to execute
+        :param write_path: local file path for resulting parquet file
+        :param schema: schema of parquet file from select query
+        :param batch_size: number of records to stream from db per batch
+        """
+        with self.session.begin() as cursor:
+            result = cursor.execute(select_query).yield_per(batch_size)
+            with pq.ParquetWriter(write_path, schema=schema) as pq_writer:
+                for part in result.partitions():
+                    pq_writer.write_batch(
+                        pyarrow.RecordBatch.from_pylist(
+                            [row._asdict() for row in part], schema=schema
+                        )
+                    )
+
     def truncate_table(
         self,
         table_to_truncate: Any,
@@ -291,9 +336,7 @@
         self.execute(sa.text(f"{truncate_query};"))
 
         # Execute VACUUM to avoid non-deterministic behavior during testing
-        with self.session.begin() as cursor:
-            cursor.execute(sa.text("END TRANSACTION;"))
-            cursor.execute(sa.text(f"VACUUM (ANALYZE) {truncat_as};"))
+        self.vacuum_analyze(table_to_truncate)
 
     def vacuum_analyze(self, table: Any) -> None:
         """RUN VACUUM (ANALYZE) on table"""
diff --git a/python_src/src/lamp_py/runtime_utils/env_validation.py b/python_src/src/lamp_py/runtime_utils/env_validation.py
index 01875194..f923f4f2 100644
--- a/python_src/src/lamp_py/runtime_utils/env_validation.py
+++ b/python_src/src/lamp_py/runtime_utils/env_validation.py
@@ -6,6 +6,7 @@
 def validate_environment(
     required_variables: List[str],
+    private_variables: Optional[List[str]] = None,
     optional_variables: Optional[List[str]] = None,
     validate_db: bool = False,
 ) -> None:
@@ -16,6 +17,9 @@
     process_logger = ProcessLogger("validate_env")
     process_logger.log_start()
 
+    if private_variables is None:
+        private_variables = []
+
     # every pipeline needs a service name for logging
     required_variables.append("SERVICE_NAME")
@@ -27,6 +31,10 @@
             "DB_PORT",
             "DB_USER",
         ]
+        # if db password is missing, db region is required to generate a
+        # token to use as the password to the cloud database
+        if os.environ.get("DB_PASSWORD", None) is None:
+            required_variables.append("DB_REGION")
 
     # check for missing variables. add found variables to our logs.
     missing_required = []
@@ -34,22 +42,19 @@
         value = os.environ.get(key, None)
         if value is None:
             missing_required.append(key)
+        # do not log private variables
+        if key in private_variables:
+            value = "**********"
         process_logger.add_metadata(**{key: value})
 
-    # if db password is missing, db region is required to generate a token to
-    # use as the password to the cloud database
-    if validate_db:
-        if os.environ.get("DB_PASSWORD", None) is None:
-            value = os.environ.get("DB_REGION", None)
-            if value is None:
-                missing_required.append("DB_REGION")
-            process_logger.add_metadata(DB_REGION=value)
-
     # for optional variables, access ones that exist and add them to logs.
     if optional_variables:
        for key in optional_variables:
            value = os.environ.get(key, None)
            if value is not None:
+                # do not log private variables
+                if key in private_variables:
+                    value = "**********"
                 process_logger.add_metadata(**{key: value})
 
     # if required variables are missing, log a failure and throw.
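Note on the new DatabaseManager.write_to_parquet helper above: a minimal usage sketch follows. The schema, table name, query, and output path are hypothetical stand-ins (the real Tableau jobs in lamp_py.tableau.jobs supply their own queries and parquet schemas), and it assumes the usual DB_* environment variables are set so a no-argument DatabaseManager() can connect.

import pyarrow
import sqlalchemy as sa

from lamp_py.postgres.postgres_utils import DatabaseManager

# hypothetical schema and query; any SELECT whose columns match the schema works
schema = pyarrow.schema(
    [
        ("service_date", pyarrow.int64()),
        ("trip_count", pyarrow.int64()),
    ]
)
query = sa.text("SELECT service_date, trip_count FROM example_daily_counts;")

db_manager = DatabaseManager()
# results are streamed in batches (default 1024 * 1024 rows per RecordBatch),
# so the full result set is never materialized in memory
db_manager.write_to_parquet(
    select_query=query,
    write_path="/tmp/example.parquet",
    schema=schema,
)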
diff --git a/python_src/src/lamp_py/tableau/hyper.py b/python_src/src/lamp_py/tableau/hyper.py
index 239f311d..a69d4808 100644
--- a/python_src/src/lamp_py/tableau/hyper.py
+++ b/python_src/src/lamp_py/tableau/hyper.py
@@ -56,8 +56,6 @@ def __init__(
             "s3://", ""
         )
 
-        self.db_manager = DatabaseManager()
-
     @property
     @abstractmethod
     def parquet_schema(self) -> pyarrow.schema:
@@ -66,13 +64,13 @@
         """
 
     @abstractmethod
-    def create_parquet(self) -> None:
+    def create_parquet(self, db_manager: DatabaseManager) -> None:
         """
         Business logic to create new Job parquet file
         """
 
     @abstractmethod
-    def update_parquet(self) -> bool:
+    def update_parquet(self, db_manager: DatabaseManager) -> bool:
         """
         Business logic to update existing Job parquet file
 
@@ -261,7 +259,7 @@ def run_hyper(self) -> None:
                 if retry_count == max_retries - 1:
                     process_log.log_failure(exception=exception)
 
-    def run_parquet(self) -> None:
+    def run_parquet(self, db_manager: DatabaseManager) -> None:
         """
         Remote parquet Create / Update runner
 
@@ -293,10 +291,10 @@ def run_parquet(self) -> None:
             # remote schema does not match expected local schema
             run_action = "create"
             upload_parquet = True
-            self.create_parquet()
+            self.create_parquet(db_manager)
         else:
             run_action = "update"
-            upload_parquet = self.update_parquet()
+            upload_parquet = self.update_parquet(db_manager)
 
         parquet_file_size_mb = os.path.getsize(self.local_parquet_path) / (
             1024 * 1024
diff --git a/python_src/src/lamp_py/tableau/jobs/gtfs_rail.py b/python_src/src/lamp_py/tableau/jobs/gtfs_rail.py
index 9af2c382..a15a25e3 100644
--- a/python_src/src/lamp_py/tableau/jobs/gtfs_rail.py
+++ b/python_src/src/lamp_py/tableau/jobs/gtfs_rail.py
@@ -2,10 +2,12 @@
 
 import pyarrow
 import pyarrow.parquet as pq
+import pyarrow.dataset as pd
 import sqlalchemy as sa
 
 from lamp_py.tableau.hyper import HyperJob
 from lamp_py.aws.s3 import download_file
+from lamp_py.postgres.postgres_utils import DatabaseManager
 
 
 class HyperGTFS(HyperJob):
@@ -34,21 +36,17 @@
     def parquet_schema(self) -> pyarrow.schema:
         """Define GTFS Table Schema"""
 
-    def create_parquet(self) -> None:
+    def create_parquet(self, db_manager: DatabaseManager) -> None:
         if os.path.exists(self.local_parquet_path):
             os.remove(self.local_parquet_path)
 
-        pq.write_table(
-            pyarrow.Table.from_pylist(
-                mapping=self.db_manager.select_as_list(
-                    sa.text(self.create_query)
-                ),
-                schema=self.parquet_schema,
-            ),
-            self.local_parquet_path,
+        db_manager.write_to_parquet(
+            select_query=sa.text(self.create_query),
+            write_path=self.local_parquet_path,
+            schema=self.parquet_schema,
         )
 
-    def update_parquet(self) -> bool:
+    def update_parquet(self, db_manager: DatabaseManager) -> bool:
         download_file(
             object_path=self.remote_parquet_path,
             file_name=self.local_parquet_path,
@@ -62,9 +60,7 @@
             f"SELECT MAX(static_version_key) FROM {self.gtfs_table_name};"
         )
 
-        max_db_key = self.db_manager.select_as_list(sa.text(max_key_query))[0][
-            "max"
-        ]
+        max_db_key = db_manager.select_as_list(sa.text(max_key_query))[0]["max"]
 
         # no update needed
         if max_db_key <= max_parquet_key:
@@ -75,21 +71,31 @@
             f" WHERE static_version_key > {max_parquet_key} ",
         )
 
-        pq.write_table(
-            pyarrow.concat_tables(
-                [
-                    pq.read_table(self.local_parquet_path),
-                    pyarrow.Table.from_pylist(
-                        mapping=self.db_manager.select_as_list(
-                            sa.text(update_query)
-                        ),
-                        schema=self.parquet_schema,
-                    ),
-                ]
-            ),
-            self.local_parquet_path,
+        db_parquet_path = "/tmp/db_local.parquet"
+
+        db_manager.write_to_parquet(
+            select_query=sa.text(update_query),
+            write_path=db_parquet_path,
+            schema=self.parquet_schema,
         )
+        old_ds = pd.dataset(self.local_parquet_path)
+        new_ds = pd.dataset(db_parquet_path)
+
+        combine_parquet_path = "/tmp/combine.parquet"
+        combine_batches = pd.dataset(
+            [old_ds, new_ds],
+            schema=self.parquet_schema,
+        ).to_batches(batch_size=1024 * 1024)
+
+        with pq.ParquetWriter(
+            combine_parquet_path, schema=self.parquet_schema
+        ) as writer:
+            for batch in combine_batches:
+                writer.write_batch(batch)
+
+        os.replace(combine_parquet_path, self.local_parquet_path)
+        os.remove(db_parquet_path)
+
         return True
diff --git a/python_src/src/lamp_py/tableau/jobs/rt_rail.py b/python_src/src/lamp_py/tableau/jobs/rt_rail.py
index ee931c0d..27449876 100644
--- a/python_src/src/lamp_py/tableau/jobs/rt_rail.py
+++ b/python_src/src/lamp_py/tableau/jobs/rt_rail.py
@@ -2,10 +2,13 @@
 
 import pyarrow
 import pyarrow.parquet as pq
+import pyarrow.compute as pc
+import pyarrow.dataset as pd
 import sqlalchemy as sa
 
 from lamp_py.tableau.hyper import HyperJob
 from lamp_py.aws.s3 import download_file
+from lamp_py.postgres.postgres_utils import DatabaseManager
 
 
 class HyperRtRail(HyperJob):
@@ -134,21 +137,19 @@
             ]
         )
 
-    def create_parquet(self) -> None:
+    def create_parquet(self, db_manager: DatabaseManager) -> None:
         create_query = self.table_query % ""
 
         if os.path.exists(self.local_parquet_path):
             os.remove(self.local_parquet_path)
 
-        pq.write_table(
-            pyarrow.Table.from_pylist(
-                mapping=self.db_manager.select_as_list(sa.text(create_query)),
-                schema=self.parquet_schema,
-            ),
-            self.local_parquet_path,
+        db_manager.write_to_parquet(
+            select_query=sa.text(create_query),
+            write_path=self.local_parquet_path,
+            schema=self.parquet_schema,
         )
 
-    def update_parquet(self) -> bool:
+    def update_parquet(self, db_manager: DatabaseManager) -> bool:
         download_file(
             object_path=self.remote_parquet_path,
             file_name=self.local_parquet_path,
@@ -162,22 +163,44 @@
             f" AND vt.service_date >= {max_start_date} ",
         )
 
-        pq.write_table(
-            pyarrow.concat_tables(
-                [
-                    pq.read_table(
-                        self.local_parquet_path,
-                        filters=[("service_date", "<", max_start_date)],
-                    ),
-                    pyarrow.Table.from_pylist(
-                        mapping=self.db_manager.select_as_list(
-                            sa.text(update_query)
-                        ),
-                        schema=self.parquet_schema,
-                    ),
-                ]
-            ),
-            self.local_parquet_path,
+        db_parquet_path = "/tmp/db_local.parquet"
+        db_manager.write_to_parquet(
+            select_query=sa.text(update_query),
+            write_path=db_parquet_path,
+            schema=self.parquet_schema,
         )
+        # update downloaded parquet file with filtered service_date
+        old_filter = pc.field("service_date") < max_start_date
+        old_batches = pd.dataset(self.local_parquet_path).to_batches(
+            filter=old_filter, batch_size=1024 * 1024
+        )
+        filter_path = "/tmp/filter_local.parquet"
+        with pq.ParquetWriter(
+            filter_path, schema=self.parquet_schema
+        ) as writer:
+            for batch in old_batches:
+                writer.write_batch(batch)
+        os.replace(filter_path, self.local_parquet_path)
+
+        joined_dataset = [
+            pd.dataset(self.local_parquet_path),
+            pd.dataset(db_parquet_path),
+        ]
+
+        combine_parquet_path = "/tmp/combine.parquet"
+        combine_batches = pd.dataset(
+            joined_dataset,
+            schema=self.parquet_schema,
+        ).to_batches(batch_size=1024 * 1024)
+
+        with pq.ParquetWriter(
+            combine_parquet_path, schema=self.parquet_schema
+        ) as writer:
+            for batch in combine_batches:
+                writer.write_batch(batch)
+
+        os.replace(combine_parquet_path, self.local_parquet_path)
+        os.remove(db_parquet_path)
+
         return True
diff --git a/python_src/src/lamp_py/tableau/pipeline.py b/python_src/src/lamp_py/tableau/pipeline.py
index b9e0b825..8184e404 100644
--- a/python_src/src/lamp_py/tableau/pipeline.py
+++ b/python_src/src/lamp_py/tableau/pipeline.py
@@ -4,6 +4,7 @@
 
 from lamp_py.runtime_utils.env_validation import validate_environment
 from lamp_py.tableau.hyper import HyperJob
+from lamp_py.postgres.postgres_utils import DatabaseManager
 from lamp_py.tableau.jobs.rt_rail import HyperRtRail
 from lamp_py.tableau.jobs.gtfs_rail import (
     HyperServiceIdByRoute,
@@ -12,7 +13,7 @@
     HyperStaticFeedInfo,
     HyperStaticRoutes,
     HyperStaticStops,
-    # HyperStaticStopTimes,
+    HyperStaticStopTimes,
     HyperStaticTrips,
 )
@@ -27,7 +28,7 @@ def create_hyper_jobs() -> List[HyperJob]:
         HyperStaticFeedInfo(),
         HyperStaticRoutes(),
         HyperStaticStops(),
-        # HyperStaticStopTimes(),
+        HyperStaticStopTimes(),
         HyperStaticTrips(),
     ]
@@ -43,12 +44,15 @@ def start_hyper_updates() -> None:
             "TABLEAU_SERVER",
             "PUBLIC_ARCHIVE_BUCKET",
         ],
+        private_variables=[
+            "TABLEAU_PASSWORD",
+        ],
     )
 
     for job in create_hyper_jobs():
         job.run_hyper()
 
-def start_parquet_updates() -> None:
+def start_parquet_updates(db_manager: DatabaseManager) -> None:
     """Run all Parquet Update jobs"""
     for job in create_hyper_jobs():
-        job.run_parquet()
+        job.run_parquet(db_manager)
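Both update_parquet implementations above combine the existing parquet file with the freshly queried one through the same batched pyarrow.dataset pattern. A condensed, self-contained sketch of that step follows; the helper name and temp path are hypothetical, and it assumes both input files already conform to the target schema.

import os

import pyarrow
import pyarrow.dataset as pd
import pyarrow.parquet as pq


def combine_parquet_files(
    old_path: str, new_path: str, schema: pyarrow.Schema
) -> None:
    """append new_path onto old_path without loading either file whole"""
    combine_path = "/tmp/combine.parquet"

    # a union dataset over both files, read back out in bounded batches
    batches = pd.dataset(
        [pd.dataset(old_path), pd.dataset(new_path)],
        schema=schema,
    ).to_batches(batch_size=1024 * 1024)

    with pq.ParquetWriter(combine_path, schema=schema) as writer:
        for batch in batches:
            writer.write_batch(batch)

    # swap the combined file into place and drop the temporary input
    os.replace(combine_path, old_path)
    os.remove(new_path)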