Commit f389579

Changes from PR152 squash before merge
* change s3 utility for parsing timestamps to return a datetime instead
  of a timestamp
* get routes to prefilter from filepaths at read time rather than when
  compiling a list of files to process
* add gtfs utility for getting a static version key from a service date
  and reuse it in the new function and in the existing one
mzappitello committed Jul 18, 2023
1 parent 978fd85 commit f389579
Showing 7 changed files with 110 additions and 92 deletions.
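In caller terms, the three bullets above replace the old rail_routes_from_timestamp(timestamp, db_manager) lookup with helpers keyed off the file being read. A minimal sketch of how they chain together, assuming a DatabaseManager is already configured (its construction is not shown in this diff) and using a made-up partition path:

```python
from lamp_py.aws.s3 import get_datetime_from_partition_path
from lamp_py.performance_manager.gtfs_utils import rail_routes_from_filepath
from lamp_py.postgres.postgres_utils import DatabaseManager

db_manager = DatabaseManager()  # assumed; real construction may take arguments

# hypothetical hour-partitioned realtime file
path = "RT_VEHICLE_POSITIONS/year=2023/month=7/day=18/hour=12/part-0.parquet"

# the partition path now parses to a timezone-aware datetime (was a float)
service_dt = get_datetime_from_partition_path(path)

# the rail route prefilter is derived from the path at read time; internally
# this goes datetime -> YYYYMMDD service date -> static_version_key_from_service_date
route_ids = rail_routes_from_filepath(path, db_manager)
```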
8 changes: 4 additions & 4 deletions python_src/src/lamp_py/aws/s3.py
@@ -279,7 +279,7 @@ def write_parquet_file(
process_logger.log_complete()


def get_utc_from_partition_path(path: str) -> float:
def get_datetime_from_partition_path(path: str) -> datetime.datetime:
"""
    parse the datetime from a partitioned s3 path
"""
@@ -289,17 +289,17 @@ def get_utc_from_partition_path(path: str) -> float:
month = int(re.findall(r"month=(\d{1,2})", path)[0])
day = int(re.findall(r"day=(\d{1,2})", path)[0])
hour = int(re.findall(r"hour=(\d{1,2})", path)[0])
date = datetime.datetime(
return_date = datetime.datetime(
year=year,
month=month,
day=day,
hour=hour,
tzinfo=datetime.timezone.utc,
)
return_date = datetime.datetime.timestamp(date)
except IndexError as _:
# handle gtfs static paths
return_date = float(re.findall(r"timestamp=(\d{10})", path)[0])
timestamp = float(re.findall(r"timestamp=(\d{10})", path)[0])
return_date = datetime.datetime.fromtimestamp(timestamp)
return return_date


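get_datetime_from_partition_path keeps the same two branches as the old helper: hour-partitioned realtime paths, and GTFS static paths keyed by a 10-digit timestamp= segment that are reached via the IndexError fallback. A small illustration of both, using made-up paths; callers that still need the old float can call .timestamp() on the result:

```python
import datetime

from lamp_py.aws.s3 import get_datetime_from_partition_path

# hypothetical hour-partitioned realtime path
rt_path = "RT_VEHICLE_POSITIONS/year=2023/month=7/day=18/hour=12/part-0.parquet"
rt_dt = get_datetime_from_partition_path(rt_path)
assert rt_dt == datetime.datetime(2023, 7, 18, 12, tzinfo=datetime.timezone.utc)

# hypothetical GTFS static path keyed by a 10-digit epoch timestamp
static_path = "FEED_INFO/timestamp=1689681600/feed_info.parquet"
static_dt = get_datetime_from_partition_path(static_path)  # falls back on IndexError

# callers that previously consumed the float timestamp can recover it
epoch = rt_dt.timestamp()
```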
144 changes: 80 additions & 64 deletions python_src/src/lamp_py/performance_manager/gtfs_utils.py
@@ -1,18 +1,17 @@
from typing import Optional, List
import datetime
from typing import Optional, List, Union

import numpy
import pandas
import sqlalchemy as sa

from lamp_py.postgres.postgres_utils import DatabaseManager
from lamp_py.postgres.postgres_schema import (
ServiceIdDates,
StaticFeedInfo,
StaticRoutes,
StaticStops,
)
from lamp_py.runtime_utils.process_logger import ProcessLogger
from lamp_py.aws.s3 import get_datetime_from_partition_path


def start_time_to_seconds(
@@ -41,6 +40,66 @@ def unique_trip_stop_columns() -> List[str]:
]


def static_version_key_from_service_date(
service_date: int, db_manager: DatabaseManager
) -> int:
"""
for a given service date, determine the correct static schedule to use
"""
# the service date must:
# * be between "feed_start_date" and "feed_end_date" in StaticFeedInfo
# * be less than or equal to "feed_active_date" in StaticFeedInfo
#
# order all static version keys by feed_active_date descending and
    # created_on date descending, then choose the first one. this handles
# multiple static schedules being issued for the same service day
live_match_query = (
sa.select(StaticFeedInfo.static_version_key)
.where(
StaticFeedInfo.feed_start_date <= service_date,
StaticFeedInfo.feed_end_date >= service_date,
StaticFeedInfo.feed_active_date <= service_date,
)
.order_by(
StaticFeedInfo.feed_active_date.desc(),
StaticFeedInfo.created_on.desc(),
)
.limit(1)
)

# "feed_start_date" and "feed_end_date" are modified for archived GTFS
# Schedule files. If processing archived static schedules, these alternate
# rules must be used for matching GTFS static to GTFS-RT data
archive_match_query = (
sa.select(StaticFeedInfo.static_version_key)
.where(
StaticFeedInfo.feed_start_date <= service_date,
StaticFeedInfo.feed_end_date >= service_date,
)
.order_by(
StaticFeedInfo.feed_start_date.desc(),
StaticFeedInfo.created_on.desc(),
)
.limit(1)
)

result = db_manager.select_as_list(live_match_query)

# if live_match_query fails, attempt to look for a match using the archive method
if len(result) == 0:
result = db_manager.select_as_list(archive_match_query)

# if this query does not produce a result, no static schedule info
# exists for this trip update data, so the data
# should not be processed until valid static schedule data exists
if len(result) == 0:
raise IndexError(
f"StaticFeedInfo table has no matching schedule for service_date={service_date}"
)

return int(result[0]["static_version_key"])


def add_static_version_key_column(
events_dataframe: pandas.DataFrame,
db_manager: DatabaseManager,
@@ -70,60 +129,15 @@ def add_static_version_key_column(
events_dataframe["static_version_key"] = 0

for date in events_dataframe["service_date"].unique():
date = int(date)
# "service_date" from events dataframe must be between "feed_start_date" and "feed_end_date" in StaticFeedInfo
# "service_date" must also be less than or equal to "feed_active_date" in StaticFeedInfo
# StaticFeedInfo, order by feed_active_date descending and created_on date descending
# this should deal with multiple static schedules being issued on the same day
# if this occurs we will use the latest issued schedule
live_match_query = (
sa.select(StaticFeedInfo.static_version_key)
.where(
StaticFeedInfo.feed_start_date <= date,
StaticFeedInfo.feed_end_date >= date,
StaticFeedInfo.feed_active_date <= date,
)
.order_by(
StaticFeedInfo.feed_active_date.desc(),
StaticFeedInfo.created_on.desc(),
)
.limit(1)
)

# "feed_start_date" and "feed_end_date" are modified for archived GTFS Schedule files
# If processing archived static schedules, these alternate rules must be used for matching
# GTFS static to GTFS-RT data
archive_match_query = (
sa.select(StaticFeedInfo.static_version_key)
.where(
StaticFeedInfo.feed_start_date <= date,
StaticFeedInfo.feed_end_date >= date,
)
.order_by(
StaticFeedInfo.feed_start_date.desc(),
StaticFeedInfo.created_on.desc(),
)
.limit(1)
service_date = int(date)
static_version_key = static_version_key_from_service_date(
service_date=service_date, db_manager=db_manager
)

result = db_manager.select_as_list(live_match_query)

# if live_match_query fails, attempt to look for a match using the archive method
if len(result) == 0:
result = db_manager.select_as_list(archive_match_query)

# if this query does not produce a result, no static schedule info
# exists for this trip update data, so the data
# should not be processed until valid static schedule data exists
if len(result) == 0:
raise IndexError(
f"StaticFeedInfo table has no matching schedule for service_date={date}"
)

service_date_mask = events_dataframe["service_date"] == date
events_dataframe.loc[service_date_mask, "static_version_key"] = int(
result[0]["static_version_key"]
)
service_date_mask = events_dataframe["service_date"] == service_date
events_dataframe.loc[
service_date_mask, "static_version_key"
] = static_version_key

process_logger.log_complete()

@@ -185,8 +199,8 @@ def add_parent_station_column(
return events_dataframe


def rail_routes_from_timestamp(
timestamp: float, db_manager: DatabaseManager
def rail_routes_from_filepath(
filepath: Union[List[str], str], db_manager: DatabaseManager
) -> List[str]:
"""
get a list of rail route_ids that were in effect on a given service date
@@ -196,18 +210,20 @@
key for a given service date, using the key with the max value (the keys
are also timestamps). then pull all the static routes with type
"""
date = datetime.datetime.utcfromtimestamp(timestamp)
service_date = f"{date.year:04}{date.month:02}{date.day:02}"
if isinstance(filepath, list):
filepath = filepath[0]

date = get_datetime_from_partition_path(filepath)
service_date = int(f"{date.year:04}{date.month:02}{date.day:02}")

svk_subquery = (
sa.select(sa.func.max(ServiceIdDates.static_version_key))
.where(ServiceIdDates.service_date == service_date)
.scalar_subquery()
static_version_key = static_version_key_from_service_date(
service_date=service_date, db_manager=db_manager
)

result = db_manager.execute(
sa.select(StaticRoutes.route_id).where(
StaticRoutes.route_type.in_([0, 1, 2]),
StaticRoutes.static_version_key == svk_subquery,
StaticRoutes.static_version_key == static_version_key,
)
)

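static_version_key_from_service_date tries the live matching rules first, falls back to the archive rules, and raises IndexError when neither query matches, so callers can defer files whose static schedule has not been loaded yet. A hedged usage sketch; the service date, error handling, and DatabaseManager construction below are illustrative, not taken from the repo:

```python
from lamp_py.performance_manager.gtfs_utils import (
    static_version_key_from_service_date,
)
from lamp_py.postgres.postgres_utils import DatabaseManager

db_manager = DatabaseManager()  # assumed construction

try:
    # service dates are YYYYMMDD integers
    svk = static_version_key_from_service_date(
        service_date=20230718, db_manager=db_manager
    )
except IndexError:
    # no matching static schedule loaded yet; skip this batch for now
    svk = None
```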
18 changes: 8 additions & 10 deletions python_src/src/lamp_py/performance_manager/l0_gtfs_rt_events.py
@@ -6,7 +6,7 @@
from sqlalchemy.dialects import postgresql

from lamp_py.aws.ecs import check_for_sigterm
from lamp_py.aws.s3 import get_utc_from_partition_path
from lamp_py.aws.s3 import get_datetime_from_partition_path
from lamp_py.postgres.postgres_schema import (
MetadataLog,
TempEventCompare,
@@ -19,7 +19,7 @@
)
from lamp_py.runtime_utils.process_logger import ProcessLogger

from .gtfs_utils import unique_trip_stop_columns, rail_routes_from_timestamp
from .gtfs_utils import unique_trip_stop_columns
from .l0_rt_trip_updates import process_tu_files
from .l0_rt_vehicle_positions import process_vp_files
from .l1_rt_trips import process_trips, load_new_trip_data
@@ -38,18 +38,21 @@ def get_gtfs_rt_paths(db_manager: DatabaseManager) -> List[Dict[str, List]]:

vp_files = get_unprocessed_files("RT_VEHICLE_POSITIONS", db_manager)
for record in vp_files:
timestamp = get_utc_from_partition_path(record["paths"][0])
timestamp = get_datetime_from_partition_path(
record["paths"][0]
).timestamp()

grouped_files[timestamp] = {
"ids": record["ids"],
"vp_paths": record["paths"],
"tu_paths": [],
"route_ids": rail_routes_from_timestamp(timestamp, db_manager),
}

tu_files = get_unprocessed_files("RT_TRIP_UPDATES", db_manager)
for record in tu_files:
timestamp = get_utc_from_partition_path(record["paths"][0])
timestamp = get_datetime_from_partition_path(
record["paths"][0]
).timestamp()
if timestamp in grouped_files:
grouped_files[timestamp]["ids"] += record["ids"]
grouped_files[timestamp]["tu_paths"] += record["paths"]
@@ -58,7 +61,6 @@ def get_gtfs_rt_paths(db_manager: DatabaseManager) -> List[Dict[str, List]]:
"ids": record["ids"],
"tu_paths": record["paths"],
"vp_paths": [],
"route_ids": rail_routes_from_timestamp(timestamp, db_manager),
}

process_logger.add_metadata(hours_found=len(grouped_files))
@@ -472,15 +474,13 @@ def process_gtfs_rt_files(db_manager: DatabaseManager) -> None:
# all events come from vp files. add tu key afterwards.
events = process_vp_files(
paths=files["vp_paths"],
route_ids=files["route_ids"],
db_manager=db_manager,
)
events["tu_stop_timestamp"] = None
elif len(files["vp_paths"]) == 0:
# all events come from tu files. add vp keys afterwards.
events = process_tu_files(
paths=files["tu_paths"],
route_ids=files["route_ids"],
db_manager=db_manager,
)
events["vp_move_timestamp"] = None
Expand All @@ -489,12 +489,10 @@ def process_gtfs_rt_files(db_manager: DatabaseManager) -> None:
# events come from tu and vp files. join them together.
vp_events = process_vp_files(
paths=files["vp_paths"],
route_ids=files["route_ids"],
db_manager=db_manager,
)
tu_events = process_tu_files(
paths=files["tu_paths"],
route_ids=files["route_ids"],
db_manager=db_manager,
)
events = combine_events(vp_events, tu_events)
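The grouping key in get_gtfs_rt_paths is now the .timestamp() of the parsed datetime rather than the float returned by the old helper, so vehicle-position and trip-update batches from the same hour partition still collapse into a single work unit, and the per-group route_ids entry disappears because route lookup moved into the file processors. A quick check of that grouping invariant with made-up paths:

```python
from lamp_py.aws.s3 import get_datetime_from_partition_path

# hypothetical paths for the same service hour
vp_path = "RT_VEHICLE_POSITIONS/year=2023/month=7/day=18/hour=12/a.parquet"
tu_path = "RT_TRIP_UPDATES/year=2023/month=7/day=18/hour=12/b.parquet"

vp_key = get_datetime_from_partition_path(vp_path).timestamp()
tu_key = get_datetime_from_partition_path(tu_path).timestamp()

# both batches land under the same grouped_files key, so the trip update
# paths merge into the existing vehicle positions entry
assert vp_key == tu_key
```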
12 changes: 7 additions & 5 deletions python_src/src/lamp_py/performance_manager/l0_rt_trip_updates.py
@@ -11,16 +11,19 @@
add_static_version_key_column,
add_parent_station_column,
unique_trip_stop_columns,
rail_routes_from_filepath,
)


def get_tu_dataframe_chunks(
to_load: Union[str, List[str]], route_ids: List[str]
to_load: Union[str, List[str]], db_manager: DatabaseManager
) -> Iterator[pandas.DataFrame]:
"""
    return iterator of dataframe chunks from a trip updates parquet file
(or list of files)
"""
route_ids = rail_routes_from_filepath(to_load, db_manager)

trip_update_columns = [
"timestamp",
"stop_time_update",
@@ -108,7 +111,7 @@ def explode_stop_time_update(


def get_and_unwrap_tu_dataframe(
paths: Union[str, List[str]], route_ids: List[str]
paths: Union[str, List[str]], db_manager: DatabaseManager
) -> pandas.DataFrame:
"""
unwrap and explode trip updates records from parquet files
@@ -124,7 +127,7 @@
# per batch, this should result in ~5-6 GB of memory use per batch
# after batch goes through explod_stop_time_update vectorize operation,
# resulting Series has negligible memory use
for batch_events in get_tu_dataframe_chunks(paths, route_ids):
for batch_events in get_tu_dataframe_chunks(paths, db_manager):
# store start_date as int64 and rename to service_date
batch_events.rename(
columns={"start_date": "service_date"}, inplace=True
@@ -213,7 +216,6 @@ def reduce_trip_updates(trip_updates: pandas.DataFrame) -> pandas.DataFrame:

def process_tu_files(
paths: Union[str, List[str]],
route_ids: List[str],
db_manager: DatabaseManager,
) -> pandas.DataFrame:
"""
@@ -224,7 +226,7 @@ def process_tu_files(
)
process_logger.log_start()

trip_updates = get_and_unwrap_tu_dataframe(paths, route_ids)
trip_updates = get_and_unwrap_tu_dataframe(paths, db_manager)
if trip_updates.shape[0] > 0:
trip_updates = add_static_version_key_column(trip_updates, db_manager)
trip_updates = add_parent_station_column(trip_updates, db_manager)
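With the route_ids parameter removed, process_tu_files only needs paths and a DatabaseManager; get_tu_dataframe_chunks derives the rail routes itself via rail_routes_from_filepath before reading any batches. A hedged sketch of the new call shape, with a hypothetical path and an assumed DatabaseManager; how the returned route_ids are applied while reading the parquet chunks is in a hunk not shown here:

```python
from lamp_py.performance_manager.l0_rt_trip_updates import process_tu_files
from lamp_py.postgres.postgres_utils import DatabaseManager

db_manager = DatabaseManager()  # assumed construction

# route prefiltering now happens inside get_tu_dataframe_chunks, which calls
# rail_routes_from_filepath(to_load, db_manager) before iterating batches
trip_update_events = process_tu_files(
    paths=["RT_TRIP_UPDATES/year=2023/month=7/day=18/hour=12/b.parquet"],
    db_manager=db_manager,
)
```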