Skip to content

Commit

Permalink
FEAT: Convert Transit Master data to Bus Events
Browse files Browse the repository at this point in the history
We want to take Transit Master data from the springboard bucket and
convert them into Bus Events that can be joined against GTFS Realtime
data. This will give us additional information about when a bus arrived
at stops as well as when it hits non-revenue timepoints.

Create a new function that takes a list of Transit Master stop crossing
parquet paths and joins it against Transit Master Geo Nodes, Routes,
Trips, and Vehicle Tables. Adjust column names, cast them appropriately,
and make some modifications so they are usable by later stages in the
pipeline.

Add test files and test cases to ensure ingestion and transformation is
happening as expected.
  • Loading branch information
mzappitello committed Aug 15, 2024
1 parent 7ebde2c commit 14682a5
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 0 deletions.
102 changes: 102 additions & 0 deletions src/lamp_py/bus_performance_manager/tm_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import List
from datetime import date

import pytz
import polars as pl

from lamp_py.runtime_utils.remote_files import RemoteFileLocations

BOSTON_TZ = pytz.timezone("EST5EDT")
UTC_TZ = pytz.utc


def create_dt_from_sam(
    service_date_col: pl.Expr, sam_time_col: pl.Expr
) -> pl.Expr:
    """
    Combine a service date with a seconds-after-midnight offset to build a
    datetime expression.

    The seconds-after-midnight value is in Boston local time; each element
    is localized to EST5EDT and then converted to UTC.

    :param service_date_col: expression yielding the service date
    :param sam_time_col: expression yielding seconds after midnight (local)
    :return: expression yielding UTC datetimes
    """
    # naive local datetime = midnight of the service date + the offset
    local_naive = service_date_col.cast(pl.Datetime) + pl.duration(
        seconds=sam_time_col
    )

    def _to_utc(naive_dt: datetime) -> datetime:
        # attach the Boston zone, then shift the wall-clock value to UTC
        return BOSTON_TZ.localize(naive_dt).astimezone(UTC_TZ)

    return local_naive.map_elements(_to_utc, return_dtype=pl.Datetime)


def generate_tm_events(tm_files: List[str]) -> pl.DataFrame:
    """
    build out events from transit master stop crossing data

    :param tm_files: s3 paths of Transit Master stop crossing parquet files
    :return: dataframe of bus events with columns
        service_date (Date), arrival_sam / departure_sam (seconds after
        midnight), vehicle_label, route_id, geo_node_id, stop_id, trip_id
        (all String), and arrival_tm / departure_tm (UTC Datetime)
    """
    # the geo node id is the transit master key and the geo node abbr is the
    # gtfs stop id
    tm_geo_nodes = pl.scan_parquet(
        RemoteFileLocations.tm_geo_node_file.get_s3_path()
    ).select(["GEO_NODE_ID", "GEO_NODE_ABBR"])

    # the route id is the transit master key and the route abbr is the gtfs
    # route id.
    # NOTE: some of these route ids have leading zeros
    tm_routes = pl.scan_parquet(
        RemoteFileLocations.tm_route_file.get_s3_path()
    ).select(["ROUTE_ID", "ROUTE_ABBR"])

    # the trip id is the transit master key and the trip serial number is the
    # gtfs trip id.
    tm_trips = pl.scan_parquet(
        RemoteFileLocations.tm_trip_file.get_s3_path()
    ).select(["TRIP_ID", "TRIP_SERIAL_NUMBER"])

    # the vehicle id is the transit master key and the property tag is the
    # vehicle label
    tm_vehicles = pl.scan_parquet(
        RemoteFileLocations.tm_vehicle_file.get_s3_path()
    ).select(["VEHICLE_ID", "PROPERTY_TAG"])

    # pull stop crossing information for a given service date and join it with
    # other dataframes using the transit master keys.
    #
    # convert the calendar id to a date object
    # remove leading zeros from route ids where they exist
    # convert arrival and departure times to utc datetimes
    # cast everything else as a string
    tm_stop_crossings = (
        pl.scan_parquet(tm_files)
        # keep only crossings with at least one actual arrival/departure time
        .filter(
            pl.col("ACT_ARRIVAL_TIME").is_not_null()
            | pl.col("ACT_DEPARTURE_TIME").is_not_null()
        )
        # NOTE(review): the next three joins use polars' default (inner)
        # join — crossings without a matching geo node, route, or vehicle
        # record are silently dropped. confirm that is intended.
        .join(tm_geo_nodes, on="GEO_NODE_ID")
        .join(tm_routes, on="ROUTE_ID")
        # left join: crossings without a trip (e.g. garage pullouts per the
        # accompanying tests) are kept with a null trip id
        .join(tm_trips, on="TRIP_ID", how="left", coalesce=True)
        .join(tm_vehicles, on="VEHICLE_ID")
        .select(
            # CALENDAR_ID appears to be a one-character prefix followed by
            # YYYYMMDD; the slice drops the prefix — TODO confirm against
            # the TM schema
            pl.col("CALENDAR_ID")
            .cast(pl.Utf8)
            .str.slice(1)
            .str.strptime(pl.Date, format="%Y%m%d")
            .alias("service_date"),
            # seconds-after-midnight values, kept for the datetime build below
            pl.col("ACT_ARRIVAL_TIME").alias("arrival_sam"),
            pl.col("ACT_DEPARTURE_TIME").alias("departure_sam"),
            pl.col("PROPERTY_TAG").cast(pl.String).alias("vehicle_label"),
            # strip leading zeros so the value matches the gtfs route id
            pl.col("ROUTE_ABBR")
            .cast(pl.String)
            .str.strip_chars_start("0")
            .alias("route_id"),
            pl.col("GEO_NODE_ID").cast(pl.String).alias("geo_node_id"),
            pl.col("GEO_NODE_ABBR").cast(pl.String).alias("stop_id"),
            pl.col("TRIP_SERIAL_NUMBER").cast(pl.String).alias("trip_id"),
        )
        # convert the seconds-after-midnight columns to UTC datetimes
        .with_columns(
            create_dt_from_sam(
                pl.col("service_date"), pl.col("arrival_sam")
            ).alias("arrival_tm"),
            create_dt_from_sam(
                pl.col("service_date"), pl.col("departure_sam")
            ).alias("departure_tm"),
        )
        .collect()
    )

    return tm_stop_crossings
83 changes: 83 additions & 0 deletions tests/bus_performance_manager/test_tm_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
from _pytest.monkeypatch import MonkeyPatch
from datetime import datetime

import polars as pl

from lamp_py.bus_performance_manager.tm_ingestion import generate_tm_events

from ..test_resources import LocalFileLocaions


def test_tm_to_bus_events(monkeypatch: MonkeyPatch) -> None:
    """
    run check_stop_crossings against every local stop crossing parquet file,
    with remote file locations patched to point at local test resources.
    """
    monkeypatch.setattr(
        "lamp_py.bus_performance_manager.tm_ingestion.RemoteFileLocations",
        LocalFileLocaions,
    )

    stop_crossing_dir = LocalFileLocaions.tm_stop_crossing.get_s3_path()
    assert os.path.exists(stop_crossing_dir)

    for parquet_name in os.listdir(stop_crossing_dir):
        check_stop_crossings(os.path.join(stop_crossing_dir, parquet_name))


def check_stop_crossings(stop_crossings_filepath: str) -> None:
    """
    run generate_tm_events on a single stop crossings parquet file and
    validate the resulting bus events dataframe.
    """
    # derive the service date from the filename: strip the ".parquet"
    # extension, then drop the single leading character before the
    # YYYYMMDD date string
    basename = os.path.basename(stop_crossings_filepath)
    date_str = basename.replace(".parquet", "")[1:]
    service_date = datetime.strptime(date_str, "%Y%m%d").date()

    # every useful record in the raw file: those with an actual arrival
    # or departure time
    raw_stop_crossings = (
        pl.scan_parquet(stop_crossings_filepath)
        .filter(
            pl.col("ACT_ARRIVAL_TIME").is_not_null()
            | pl.col("ACT_DEPARTURE_TIME").is_not_null()
        )
        .collect()
    )

    # run the generate tm events function on our input file
    bus_events = generate_tm_events(tm_files=[stop_crossings_filepath])

    # data was extracted from the filepath
    assert not bus_events.is_empty()

    # exactly one service date, matching the one parsed from the filename
    assert set(bus_events["service_date"]) == {service_date}

    # joining didn't lose any records from the raw dataset
    assert len(bus_events) == len(raw_stop_crossings)

    # crossings without trips should all be garage pullouts
    bus_garages = {
        "soham",
        "lynn",
        "prwb",
        "charl",
        "cabot",
        "arbor",
        "qubus",
        "somvl",
    }
    non_trip_events = bus_events.filter(pl.col("trip_id").is_null())
    assert set(non_trip_events["stop_id"]).issubset(bus_garages)

    # no arrival or departure timestamp precedes the start of the service date
    assert bus_events.filter(
        (pl.col("arrival_tm") < service_date)
        | (pl.col("departure_tm") < service_date)
    ).is_empty()

    # departures never precede arrivals
    assert bus_events.filter(
        pl.col("arrival_tm") > pl.col("departure_tm")
    ).is_empty()

    # no leading zeros remain on route, trip, or stop ids
    assert bus_events.filter(
        pl.col("route_id").str.starts_with("0")
        | pl.col("trip_id").str.starts_with("0")
        | pl.col("stop_id").str.starts_with("0")
    ).is_empty()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 14682a5

Please sign in to comment.