Skip to content

Commit 4bc1f7d

Browse files
committed
Make price retrieval more robust
1 parent 347c2aa commit 4bc1f7d

3 files changed

Lines changed: 137 additions & 31 deletions

File tree

app/util.py

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212
from util.oebb import (
1313
get_station_details,
1414
get_travel_action_id,
15-
get_connection_ids,
16-
get_price_for_connection,
15+
get_price_for_route,
1716
)
1817

1918

@@ -117,25 +116,20 @@ def get_price_generator(
117116
current_step += 1
118117
current_message = _("Processing connections")
119118
yield render(current_message, current_step, total_steps)
120-
connection_ids = get_connection_ids(
121-
travel_action_id,
122-
origin_details,
123-
destination_details,
119+
120+
current_step += 1
121+
current_message = _("Retrieving price")
122+
yield render(current_message, current_step, total_steps)
123+
124+
price = get_price_for_route(
125+
travel_action_id=travel_action_id,
126+
origin_details=origin_details,
127+
destination_details=destination_details,
124128
date=date,
125129
has_vc66=has_vc66,
126-
get_only_first=False,
127130
access_token=access_token,
128131
)
129-
if not connection_ids:
130-
logger.warning("Could not process connections.")
131-
current_message = _("Failed to process connections")
132-
yield render(current_message)
133-
return
134132

135-
current_step += 1
136-
current_message = _("Retrieving price")
137-
yield render(current_message, current_step, total_steps)
138-
price = get_price_for_connection(connection_ids, access_token=access_token)
139133
if not price:
140134
logger.warning("Could not retrieve price.")
141135
current_message = _("Failed to retrieve price")

tests/test_util.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
from datetime import datetime, timedelta
34

45
from typing import TYPE_CHECKING
56

@@ -45,3 +46,24 @@ def test_get_price(self) -> None:
4546
assert price
4647
assert price > 5
4748
assert price < 50
49+
50+
def test_get_price_for_route_with_few_prices(self) -> None:
    """A price is still found on a route where only few connections have one."""
    # 08:00 on the Monday strictly after today; when today is already a
    # Monday this lands a full week ahead.
    now = datetime.now()
    next_monday = (now + timedelta(days=7 - now.weekday())).replace(
        hour=8, minute=0, second=0, microsecond=0
    )

    price = get_price(
        "Moosburg in Ktn Ortsmitte",
        "Klagenfurt Heuplatz ",
        date=next_monday,
        has_vc66=True,
        take_median=True,
    )

    assert price
    assert 0 < price < 50

util/oebb.py

Lines changed: 105 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import time
33
from datetime import datetime, timedelta
44
from statistics import median
5-
from typing import Optional, Dict, List, Union
5+
from typing import Optional, Dict, List, Union, Tuple
66

77
import requests
88
from sentry_sdk import add_breadcrumb
@@ -23,6 +23,8 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
MAX_BATCHES = 5
27+
2628

2729
def init_user_data(
2830
access_token: str,
@@ -275,6 +277,32 @@ def get_connection_ids(
275277
access_token: Optional[str] = None,
276278
host: str = CONFIG["host"],
277279
) -> Optional[Union[str, List[str]]]:
280+
result = _get_connection_ids_with_next_date(
281+
travel_action_id=travel_action_id,
282+
origin_station_details=origin_station_details,
283+
destination_station_details=destination_station_details,
284+
date=date,
285+
has_vc66=has_vc66,
286+
get_only_first=get_only_first,
287+
access_token=access_token,
288+
host=host,
289+
)
290+
if result is None:
291+
return None
292+
connection_ids, _ = result
293+
return connection_ids
294+
295+
296+
def _get_connection_ids_with_next_date(
297+
travel_action_id: str,
298+
origin_station_details: Dict[str, Union[str, int, None]],
299+
destination_station_details: Dict[str, Union[str, int, None]],
300+
date: Optional[datetime] = None,
301+
has_vc66: bool = False,
302+
get_only_first: bool = True,
303+
access_token: Optional[str] = None,
304+
host: str = CONFIG["host"],
305+
) -> Optional[Tuple[Union[str, List[str]], Optional[datetime]]]:
278306
url = host + API_PATHS["timetable"]
279307
if not date:
280308
date = datetime.utcnow()
@@ -394,7 +422,7 @@ def get_connection_ids(
394422
except requests.exceptions.JSONDecodeError:
395423
logger.error(
396424
f"Failed to decode JSON response from timetable. Text: {r.text[:500]}"
397-
) # Log snippet of text
425+
)
398426
return None
399427

400428
connections = response_json.get("connections")
@@ -404,24 +432,41 @@ def get_connection_ids(
404432
)
405433
return None
406434

435+
next_date: Optional[datetime] = None
436+
last_connection = connections[-1]
437+
last_departure_str = last_connection.get("from", {}).get(
438+
"departure"
439+
) or last_connection.get("departure")
440+
if last_departure_str:
441+
try:
442+
last_departure_str_trimmed = last_departure_str[:19]
443+
last_departure = datetime.strptime(
444+
last_departure_str_trimmed, "%Y-%m-%dT%H:%M:%S"
445+
)
446+
next_date = last_departure + timedelta(minutes=1)
447+
except (ValueError, TypeError) as e:
448+
logger.warning(
449+
f"Could not parse last departure time '{last_departure_str}': {e}"
450+
)
451+
407452
if get_only_first:
408453
connection_id = connections[0].get("id")
409454
if not connection_id:
410455
logger.error(
411456
f"First connection is missing an ID. Connection data: {connections[0]}"
412457
)
413458
return None
414-
return connection_id
459+
return connection_id, next_date
415460

416461
connection_ids = [
417462
connection.get("id") for connection in connections if connection.get("id")
418463
]
419-
if not connection_ids: # If all connections were missing IDs
464+
if not connection_ids:
420465
logger.error(
421466
f"All connections were missing IDs. Connections data: {connections}"
422467
)
423468
return None
424-
return connection_ids
469+
return connection_ids, next_date
425470

426471

427472
def get_price_for_connection(
@@ -503,6 +548,57 @@ def get_price_for_connection(
503548
return price
504549

505550

551+
def get_price_for_route(
    travel_action_id: str,
    origin_details: Dict[str, Union[str, int, None]],
    destination_details: Dict[str, Union[str, int, None]],
    date: Optional[datetime] = None,
    has_vc66: bool = False,
    access_token: Optional[str] = None,
    host: str = CONFIG["host"],
    max_batches: Optional[int] = None,
) -> Optional[float]:
    """Look up a price for a route, paging through connection batches.

    Connections are fetched in batches starting at ``date``. Each batch is
    handed to ``get_price_for_connection``; the first non-``None`` price is
    returned. When a batch yields no price, the search continues from the
    follow-up date reported for that batch.

    Args:
        travel_action_id: Travel action ID forwarded to the timetable lookup.
        origin_details: Station details of the origin.
        destination_details: Station details of the destination.
        date: Start of the search window; defaults to the current UTC time.
        has_vc66: Forwarded unchanged to the timetable lookup.
        access_token: Forwarded to both the timetable and the price lookup.
        host: API host used for both lookups.
        max_batches: Upper bound on the number of batches to try. ``None``
            (the default) uses the module-level ``MAX_BATCHES``.

    Returns:
        The first price found, or ``None`` when a batch could not be
        fetched, no follow-up date could be determined, or the batch limit
        was reached without finding a price.
    """
    if not date:
        # NOTE(review): utcnow() yields a naive datetime and is deprecated
        # since Python 3.12; kept because the timetable lookup uses the same
        # default -- confirm before switching to an aware datetime.
        date = datetime.utcnow()

    # Resolve the limit at call time so the module constant stays the default.
    batch_limit = MAX_BATCHES if max_batches is None else max_batches

    current_date = date
    for batch_num in range(batch_limit):
        logger.info(
            f"Fetching connections batch {batch_num + 1}/{batch_limit} starting from {current_date}."
        )
        result = _get_connection_ids_with_next_date(
            travel_action_id=travel_action_id,
            origin_station_details=origin_details,
            destination_station_details=destination_details,
            date=current_date,
            has_vc66=has_vc66,
            get_only_first=False,
            access_token=access_token,
            host=host,
        )
        if result is None:
            logger.warning(f"Could not get connections in batch {batch_num + 1}.")
            return None

        connection_ids, next_date = result

        price = get_price_for_connection(
            connection_ids, access_token=access_token, host=host
        )
        if price is not None:
            return price

        logger.info(f"No price found in batch {batch_num + 1}, trying next batch.")

        # Without a follow-up date there is no way to advance the window.
        if next_date is None:
            logger.warning("Could not determine next batch date, stopping.")
            return None

        current_date = next_date

    logger.warning(f"Could not find a price after {batch_limit} batches.")
    return None
600+
601+
506602
def get_price(
507603
origin: str,
508604
destination: str,
@@ -543,19 +639,13 @@ def get_price(
543639
logger.warning("Could not get travel action ID.")
544640
return None
545641

546-
connection_ids = get_connection_ids(
547-
travel_action_id,
548-
origin_details,
549-
destination_details,
642+
price = get_price_for_route(
643+
travel_action_id=travel_action_id,
644+
origin_details=origin_details,
645+
destination_details=destination_details,
550646
date=date,
551647
has_vc66=has_vc66,
552-
get_only_first=(not take_median),
553648
access_token=access_token,
554649
host=CONFIG["host"],
555650
)
556-
if not connection_ids:
557-
logger.warning("Could not get connection ID.")
558-
return None
559-
560-
price = get_price_for_connection(connection_ids, access_token=access_token)
561651
return price

0 commit comments

Comments
 (0)