Commit 21571b7

Merge pull request #92 from NatLabRockies/copilot/better-route-filtering-handling
Default to trip-level route filtering; use block-level only when deadhead is requested
2 parents fc834ca + e10e9ae commit 21571b7

4 files changed

Lines changed: 207 additions & 27 deletions

docs/index.md

Lines changed: 11 additions & 0 deletions
@@ -34,6 +34,17 @@ predictor.run(
 )
 ```

+By default, route filtering works at the **trip level**, so individual trips on the requested routes are always included even if their block also serves other routes. If you enable deadhead trip estimation, filtering automatically switches to **block level** to ensure complete blocks (see [](prediction) for details):
+
+```python
+predictor.run(
+    date="2023/08/02",
+    routes=["806", "807"],
+    add_mid_block_deadhead=True,
+    add_depot_deadhead=True,
+)
+```
+
 For a full example, see [](examples/Utah_Transit_Agency_example).

 ## Available Models
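
Note on the docs change above: the difference between the two filtering modes is easiest to see on a toy table. The sketch below is an illustration only. It uses plain Python lists instead of the project's DataFrames, the helper names `trip_level_filter` and `block_level_filter` are hypothetical, and the block-level helper mimics the behavior the docs describe for `route_method="exclusive"` rather than reproducing `filter_blocks_by_route`, whose body is not part of this diff.

```python
# Toy trips table. Column names mirror those used in the diff; the data is made up.
TRIPS = [
    {"trip_id": "T1", "block_id": "B1", "route_short_name": "R1"},
    {"trip_id": "T2", "block_id": "B1", "route_short_name": "R2"},
    {"trip_id": "T3", "block_id": "B2", "route_short_name": "R1"},
]


def trip_level_filter(trips, routes):
    """Keep every trip whose route is in the requested set."""
    wanted = set(routes)
    return [t for t in trips if t["route_short_name"] in wanted]


def block_level_filter(trips, routes):
    """Keep only blocks in which every trip is on a requested route."""
    wanted = set(routes)
    routes_by_block = {}
    for t in trips:
        routes_by_block.setdefault(t["block_id"], set()).add(t["route_short_name"])
    # A block is "complete" when the set of routes it serves is a subset of the request.
    complete = {b for b, served in routes_by_block.items() if served <= wanted}
    return [t for t in trips if t["block_id"] in complete]


trip_ids = [t["trip_id"] for t in trip_level_filter(TRIPS, ["R1"])]
block_ids = [t["trip_id"] for t in block_level_filter(TRIPS, ["R1"])]
print(trip_ids)   # ['T1', 'T3']
print(block_ids)  # ['T3']
```

Block B1 also serves R2, so block-level filtering drops T1 along with T2, which is the same outcome the new unit tests in this commit assert.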

docs/prediction.md

Lines changed: 29 additions & 1 deletion
@@ -13,6 +13,25 @@ The full workflow proceeds in four stages:

 RouteE-Transit reads a standard GTFS feed directory and loads trips along with their shape traces, stop locations, and stop times. Users can optionally filter to a specific service date and/or a subset of routes. The shape traces and scheduled stop times are used downstream to estimate average speed and distance at the road link level.

+### Route filtering
+
+By default, route filtering is **trip-level**: only trips whose `route_short_name` appears in the `routes` list are included, regardless of what other routes share the same GTFS block. This gives the most intuitive results, especially for agencies where interlining is common (e.g., a bus block might serve both route 5 and route 21).
+
+When deadhead trips are requested (`add_mid_block_deadhead=True` or `add_depot_deadhead=True`), filtering automatically switches to **block-level** mode. In block-level mode, entire blocks are excluded if *any* trip in the block belongs to a route not in the requested set. This is necessary because deadhead estimation requires complete blocks. If block-level filtering removes all trips but trip-level filtering would have kept some, the error message will explain this and suggest alternatives (either disabling deadhead or adding the additional interlined routes to the `routes` list).
+
+When using `filter_trips()` directly, you can control this behavior with the `use_block_filter` parameter:
+
+```python
+predictor.load_gtfs_data()
+
+# Trip-level filtering (default) — individual trips on route "5" are
+# always kept, even if their block also serves route "21"
+predictor.filter_trips(routes=["5"])
+
+# Block-level filtering — only blocks that exclusively serve route "5"
+predictor.filter_trips(routes=["5"], use_block_filter=True)
+```
+
 See [](data:gtfs-reqs) for the full list of required GTFS files and fields.

 ## 2) Deadhead Trip Inference
@@ -77,6 +96,15 @@ predictor = GTFSEnergyPredictor(
 )

 # Option 1: Use the convenience method (recommended)
+# By default, only revenue trips are included (no deadhead).
+trip_results = predictor.run(
+    date="2023/08/02",
+    routes=["205"],
+)
+
+# Option 2: Include deadhead trips and HVAC impacts
+# When deadhead is enabled with route filtering, block-level filtering
+# is used automatically to ensure complete blocks.
 trip_results = predictor.run(
     date="2023/08/02",
     routes=["205"],
@@ -85,7 +113,7 @@ trip_results = predictor.run(
     add_hvac=True,
 )

-# Option 2: Step-by-step processing for more control
+# Option 3: Step-by-step processing for more control
 predictor.load_gtfs_data()
 predictor.filter_trips(date="2023/08/02", routes=["205"])
 predictor.add_mid_block_deadhead()  # Between-trip deadhead
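
Note on the docs change above: the new text suggests adding the interlined routes to the `routes` list when block-level filtering removes everything. One way to discover those routes is sketched below. This is a hypothetical helper, not part of RouteE-Transit: `interlined_routes` and the toy data are assumptions for illustration, and the trips are plain dictionaries rather than the predictor's DataFrame.

```python
def interlined_routes(trips, routes):
    """Return routes that share a block with a requested route but were not requested."""
    wanted = set(routes)
    # Blocks touched by at least one requested route.
    blocks = {t["block_id"] for t in trips if t["route_short_name"] in wanted}
    # Every route served by those blocks, minus the ones already requested.
    served = {t["route_short_name"] for t in trips if t["block_id"] in blocks}
    return sorted(served - wanted)


# Made-up example echoing the docs: one block serves both route 5 and route 21.
TOY_TRIPS = [
    {"trip_id": "T1", "block_id": "B1", "route_short_name": "5"},
    {"trip_id": "T2", "block_id": "B1", "route_short_name": "21"},
    {"trip_id": "T3", "block_id": "B2", "route_short_name": "5"},
    {"trip_id": "T4", "block_id": "B3", "route_short_name": "9"},
]

extra = interlined_routes(TOY_TRIPS, ["5"])
print(extra)  # ['21']
```

Passing `["5"] + extra` as `routes` would then keep block B1 intact under block-level filtering, at the cost of including the route 21 trips in the results.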

routee/transit/predictor.py

Lines changed: 82 additions & 24 deletions
@@ -228,8 +228,8 @@ def run(
         date: str | None = None,
         routes: list[str] | None = None,
         # Processing options
-        add_mid_block_deadhead: bool = True,
-        add_depot_deadhead: bool = True,
+        add_mid_block_deadhead: bool = False,
+        add_depot_deadhead: bool = False,
         # Energy prediction options
         add_hvac: bool = True,
         save_results: bool = True,
@@ -254,11 +254,16 @@ def run(
             If None, all trips across all service dates are included.
         routes : list[str], optional
             Filter trips to specific route IDs. If None, all routes are included.
-        add_mid_block_deadhead : bool, default=True
+        add_mid_block_deadhead : bool, default=False
             Whether to add deadhead trips between consecutive revenue trips.
-        add_depot_deadhead : bool, default=True
+            When True and ``routes`` is specified, block-level filtering is used
+            to ensure only blocks that exclusively serve the selected routes are
+            included (required for correct deadhead estimation).
+        add_depot_deadhead : bool, default=False
             Whether to add deadhead trips from/to depots at start/end of blocks.
             Requires depot_path to be set during initialization.
+            When True and ``routes`` is specified, block-level filtering is used
+            (see ``add_mid_block_deadhead``).
         add_hvac : bool, default=True
             Whether to add HVAC energy consumption based on ambient temperature.
         save_results : bool, default=True
@@ -306,8 +311,16 @@ def run(
         self.load_gtfs_data()

         # Step 2: Filter trips if requested
+        # Use block-level filtering when deadhead trips are requested, because
+        # deadhead estimation requires complete blocks. Otherwise, use the
+        # more intuitive trip-level filtering.
+        needs_deadhead = add_mid_block_deadhead or add_depot_deadhead
         if date is not None or routes is not None:
-            self.filter_trips(date=date, routes=routes)
+            self.filter_trips(
+                date=date,
+                routes=routes,
+                use_block_filter=needs_deadhead and routes is not None,
+            )

         # Add start time, end time, and duration of each trip
         self.add_trip_times()
@@ -553,25 +566,44 @@ def filter_trips(
         self,
         date: str | None = None,
         routes: list[str] | None = None,
+        use_block_filter: bool = False,
     ) -> "GTFSEnergyPredictor":
         """
         Filter trips by date and/or routes.

         This method can be called after load_gtfs_data() to restrict the analysis
         to specific dates or routes. Can be called multiple times to refine filters.

-        Args:
-            date: Date to filter trips (format: "YYYY-MM-DD" or datetime object).
-                If None, keeps all currently loaded trips.
-            routes: List of route_short_name values to filter by.
-                If None, keeps all currently loaded routes.
-
-        Returns:
-            Self for method chaining
+        Parameters
+        ----------
+        date : str, optional
+            Date to filter trips (format: "YYYY-MM-DD" or datetime object).
+            If None, keeps all currently loaded trips.
+        routes : list[str], optional
+            List of route_short_name values to filter by.
+            If None, keeps all currently loaded routes.
+        use_block_filter : bool, default=False
+            When True, uses block-level filtering via
+            ``filter_blocks_by_route`` with ``route_method="exclusive"``.
+            This means entire blocks are excluded if any trip in the block
+            belongs to a route not in ``routes``. This is appropriate when
+            deadhead trips are being estimated, because we need complete
+            blocks. When False (the default), trips are filtered purely
+            at the trip level so that individual trips on the requested
+            routes are always included regardless of what other routes
+            share the same block.

-        Raises:
-            RuntimeError: If GTFS data hasn't been loaded yet
-            ValueError: If no trips match the specified filters
+        Returns
+        -------
+        GTFSEnergyPredictor
+            Self for method chaining.
+
+        Raises
+        ------
+        RuntimeError
+            If GTFS data hasn't been loaded yet.
+        ValueError
+            If no trips match the specified filters.
         """
         if self.feed is None or self.trips.empty:
             raise RuntimeError("Must call load_gtfs_data() before filtering trips")
@@ -588,14 +620,40 @@ def filter_trips(

         # Filter by routes
         if routes is not None:
-            self.trips = filter_blocks_by_route(
-                trips=self.trips,
-                routes=routes,
-                route_column="route_short_name",
-                route_method="exclusive",
-            )
-            if len(self.trips) == 0:
-                raise ValueError("No trips found for the selected routes and date.")
+            if use_block_filter:
+                pre_filter_trips = self.trips
+                self.trips = filter_blocks_by_route(
+                    trips=self.trips,
+                    routes=routes,
+                    route_column="route_short_name",
+                    route_method="exclusive",
+                )
+                if len(self.trips) == 0:
+                    # Check whether trip-level filtering would have kept any
+                    # trips. This tells the user whether the issue is that
+                    # no trips match the routes at all, or that block-level
+                    # filtering is too restrictive.
+                    trip_level_count = int(
+                        pre_filter_trips["route_short_name"].isin(routes).sum()
+                    )
+                    if trip_level_count > 0:
+                        raise ValueError(
+                            f"No trips remain after block-level route filtering, "
+                            f"but {trip_level_count} trip(s) match at the trip "
+                            f"level. This can happen when blocks contain trips "
+                            f"from routes not in the requested set (e.g. "
+                            f"interlined routes). Consider running without "
+                            f"deadhead trips to use trip-level filtering, or "
+                            f"add the additional routes to the 'routes' "
+                            f"parameter."
+                        )
+                    raise ValueError("No trips found for the selected routes and date.")
+            else:
+                self.trips = self.trips[
+                    self.trips["route_short_name"].isin(routes)
+                ].copy()
+                if len(self.trips) == 0:
+                    raise ValueError("No trips found for the selected routes and date.")

         # Update shapes to match filtered trips
         shape_ids = self.trips.shape_id.unique()
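
Note on the `run()` change above: the choice of filtering mode reduces to one boolean expression, `needs_deadhead and routes is not None`. The sketch below tabulates that expression for every combination of inputs. The function name `choose_block_filter` is hypothetical and exists only for this illustration; the expression inside it is copied from the diff.

```python
from itertools import product


def choose_block_filter(add_mid_block_deadhead, add_depot_deadhead, routes):
    """Mirror of the expression passed as use_block_filter in run()."""
    needs_deadhead = add_mid_block_deadhead or add_depot_deadhead
    return needs_deadhead and routes is not None


# Enumerate all eight combinations of the two flags and the routes argument.
table = {}
for mid, depot, routes in product([False, True], [False, True], [None, ["205"]]):
    key = (mid, depot, routes is not None)
    table[key] = choose_block_filter(mid, depot, routes)
    print(f"mid={mid!s:5} depot={depot!s:5} routes_given={key[2]!s:5} -> {table[key]}")
```

Block-level filtering is selected only when at least one deadhead flag is set and a `routes` list is supplied. With both flags at their new default of `False`, `run()` always filters at the trip level.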

tests/test_predictor.py

Lines changed: 85 additions & 2 deletions
@@ -46,8 +46,7 @@ def test_load_gtfs_data(self, mock_from_dir: MagicMock) -> None:
         self.assertEqual(self.predictor.trips.iloc[0]["trip_id"], "T1")
         self.assertEqual(len(self.predictor.shapes), 1)

-    @patch("routee.transit.predictor.filter_blocks_by_route")
-    def test_filter_trips(self, mock_filter: MagicMock) -> None:
+    def test_filter_trips_by_date(self) -> None:
         # Setup pre-loaded state
         self.predictor.feed = MagicMock()
         self.predictor.feed.get_service_ids_from_date.return_value = ["S1"]
@@ -65,6 +64,90 @@ def test_filter_trips(self, mock_filter: MagicMock) -> None:
         self.assertEqual(len(self.predictor.trips), 1)
         self.assertEqual(self.predictor.trips.iloc[0]["service_id"], "S1")

+    def test_filter_trips_trip_level_default(self) -> None:
+        """Trip-level filtering keeps individual trips even if the block has other routes."""
+        self.predictor.feed = MagicMock()
+        self.predictor.trips = pd.DataFrame(
+            {
+                "trip_id": ["T1", "T2", "T3"],
+                "block_id": ["B1", "B1", "B2"],
+                "route_short_name": ["R1", "R2", "R1"],
+                "service_id": ["S1", "S1", "S1"],
+                "shape_id": ["SH1", "SH2", "SH3"],
+            }
+        )
+        self.predictor.feed.shapes = pd.DataFrame({"shape_id": ["SH1", "SH2", "SH3"]})
+
+        self.predictor.filter_trips(routes=["R1"])
+
+        # Should keep T1 and T3 (both on R1) even though T1's block also has R2
+        self.assertEqual(len(self.predictor.trips), 2)
+        self.assertListEqual(
+            sorted(self.predictor.trips["trip_id"].tolist()), ["T1", "T3"]
+        )
+
+    def test_filter_trips_block_level(self) -> None:
+        """Block-level filtering excludes blocks that contain trips from other routes."""
+        self.predictor.feed = MagicMock()
+        self.predictor.trips = pd.DataFrame(
+            {
+                "trip_id": ["T1", "T2", "T3"],
+                "block_id": ["B1", "B1", "B2"],
+                "route_short_name": ["R1", "R2", "R1"],
+                "service_id": ["S1", "S1", "S1"],
+                "shape_id": ["SH1", "SH2", "SH3"],
+            }
+        )
+        self.predictor.feed.shapes = pd.DataFrame({"shape_id": ["SH1", "SH2", "SH3"]})
+
+        self.predictor.filter_trips(routes=["R1"], use_block_filter=True)
+
+        # Block B1 has R2 so it is excluded entirely; only B2/T3 remains
+        self.assertEqual(len(self.predictor.trips), 1)
+        self.assertEqual(self.predictor.trips.iloc[0]["trip_id"], "T3")
+
+    def test_filter_trips_block_level_error_message(self) -> None:
+        """Block-level filtering gives a helpful error when all trips are excluded."""
+        self.predictor.feed = MagicMock()
+        self.predictor.trips = pd.DataFrame(
+            {
+                "trip_id": ["T1", "T2"],
+                "block_id": ["B1", "B1"],
+                "route_short_name": ["R1", "R2"],
+                "service_id": ["S1", "S1"],
+                "shape_id": ["SH1", "SH2"],
+            }
+        )
+        self.predictor.feed.shapes = pd.DataFrame({"shape_id": ["SH1", "SH2"]})
+
+        with self.assertRaises(ValueError) as ctx:
+            self.predictor.filter_trips(routes=["R1"], use_block_filter=True)
+
+        self.assertIn("block-level", str(ctx.exception))
+        self.assertIn("interlined", str(ctx.exception))
+        # Should mention the count of trip-level matches
+        self.assertIn("1 trip(s) match at the trip level", str(ctx.exception))
+
+    def test_filter_trips_block_level_no_trips_at_all(self) -> None:
+        """Block-level filtering with no matching trips gives a generic error."""
+        self.predictor.feed = MagicMock()
+        self.predictor.trips = pd.DataFrame(
+            {
+                "trip_id": ["T1", "T2"],
+                "block_id": ["B1", "B1"],
+                "route_short_name": ["R1", "R2"],
+                "service_id": ["S1", "S1"],
+                "shape_id": ["SH1", "SH2"],
+            }
+        )
+        self.predictor.feed.shapes = pd.DataFrame({"shape_id": ["SH1", "SH2"]})
+
+        with self.assertRaises(ValueError) as ctx:
+            self.predictor.filter_trips(routes=["R99"], use_block_filter=True)
+
+        # No trip-level matches either, so the generic error should be used
+        self.assertIn("No trips found", str(ctx.exception))
+
     def test_add_trip_times(self) -> None:
         # Setup pre-loaded state
         self.predictor.feed = MagicMock()
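
Note on the two error-path tests above: they pin down which message is raised when block-level filtering leaves nothing. The decision can be sketched without pandas as follows. This is an approximation for illustration only: `empty_filter_message` is a hypothetical name, the trips are plain dictionaries, and the first message is abbreviated compared with the full text in `predictor.py`.

```python
def empty_filter_message(pre_filter_trips, routes):
    """Pick the error text for an empty block-level result (wording abbreviated)."""
    wanted = set(routes)
    # Count the trips that plain trip-level filtering would have kept.
    trip_level_count = sum(
        1 for t in pre_filter_trips if t["route_short_name"] in wanted
    )
    if trip_level_count > 0:
        # Trips exist on the requested routes, so block-level filtering was too strict.
        return (
            "No trips remain after block-level route filtering, "
            f"but {trip_level_count} trip(s) match at the trip level."
        )
    # Nothing matches at either level, so the generic message applies.
    return "No trips found for the selected routes and date."


# Same shape as the fixtures in the tests: one block shared by R1 and R2.
FIXTURE = [
    {"trip_id": "T1", "block_id": "B1", "route_short_name": "R1"},
    {"trip_id": "T2", "block_id": "B1", "route_short_name": "R2"},
]

interlined_msg = empty_filter_message(FIXTURE, ["R1"])
generic_msg = empty_filter_message(FIXTURE, ["R99"])
print(interlined_msg)
print(generic_msg)
```

Counting trip-level matches on the pre-filter table is what lets the error distinguish a misspelled or absent route from an interlined block.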
