Skip to content

Commit ea1d5a4

Browse files
committed
Merge branch 'Fix_lane_relations' into '3-identify-map-segments-intersection-roundabout-in-lanes-e-g-wod-motion'
Fix lane relations See merge request fb-fi/data/omega-prime!11
2 parents f4d57ad + 3371c54 commit ea1d5a4

4 files changed

Lines changed: 94 additions & 45 deletions

File tree

intersection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
# mapsegment.plot_intersections(output_path)
2020
print(sys.getrecursionlimit())
2121

22-
file = Path("/scenario-center-playground/data/training_20s/training_20s_863518a3eb519031.mcap")
22+
file = Path(
23+
"/scenario-center-playground/data/training_20s/training_20s.tfrecord-00106-of-01000/training_20s.tfrecord-00106-of-01000_8450f9adb20760a5.mcap"
24+
)
2325
output_path = Path("/scenario-center-playground/scenarios/") / file.stem
2426
output_path.mkdir(parents=True, exist_ok=True)
25-
r = omega_prime.Recording.from_file(filepath=file, split_lanes=True, split_lanes_lenght=5)
27+
r = omega_prime.Recording.from_file(filepath=file, compute_polygons=True, split_lanes=True, split_lanes_lenght=10)
2628
r.create_mapsegments()
2729
mapsegment = r.mapsegment
2830
locator = omega_prime.Locator.from_map(r.map)

omega_prime/intersection_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def find_misclassified_intersection_lanes(lane_dict, all_lanes):
2323
return np.array(misclassified_lanes)
2424

2525

26-
def add_lanexy_to_graph(G, locator):
26+
def add_lanexy_to_graph(G, lanes):
2727
"""
2828
Adds lane coordinates to the graph as node attributes.
2929
@@ -34,7 +34,7 @@ def add_lanexy_to_graph(G, locator):
3434
Returns:
3535
networkx.Graph: The updated graph with lane coordinates as node attributes.
3636
"""
37-
for lane in locator.all_lanes:
37+
for lane in lanes.values():
3838
if lane.idx.lane_id in G.nodes:
3939
G.nodes[lane.idx.lane_id]["x"] = shapely.centroid(lane.centerline).x
4040
G.nodes[lane.idx.lane_id]["y"] = shapely.centroid(lane.centerline).y

omega_prime/mapsemgents.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from shapely.strtree import STRtree
99
from shapely.geometry import Point
1010
from collections import namedtuple as nt
11+
from .locator import Locator
1112

1213
concave_hull_ratio = 0.3
1314

@@ -17,10 +18,10 @@ class Mapsegmentation:
1718
A class that handels multiple intersections on a single map.
1819
"""
1920

20-
def __init__(self, map, locator, lane_buffer=None, intersection_overlap_buffer=None):
21-
self.map = map
22-
self.locator = locator
23-
self.lanes = locator.all_lanes
21+
def __init__(self, recording, lane_buffer=None, intersection_overlap_buffer=None):
22+
self.map = recording.map
23+
self.locator = Locator.from_map(recording.map)
24+
self.lanes = recording.map.lanes
2425
self.intersections = []
2526
self.lane_dict = {}
2627
self.lane_successors_dict = {}
@@ -37,7 +38,7 @@ def __init__(self, map, locator, lane_buffer=None, intersection_overlap_buffer=N
3738
self.segments = []
3839

3940
segment_name = nt("SegmentName", ["lane_id", "segment_idx", "segment"])
40-
for lane in self.locator.all_lanes:
41+
for lane in self.lanes.values():
4142
self.lane_segment_dict[lane.idx.lane_id] = segment_name(lane.idx.lane_id, None, None)
4243

4344
def init_intersections(self):
@@ -53,7 +54,7 @@ def init_intersections(self):
5354
self.parallel_lane_dict = self.create_parallel_lane_dict()
5455
self.get_intersecting_lanes()
5556
self.graph_intersection_detection()
56-
self.G = add_lanexy_to_graph(self.G, self.locator)
57+
self.G = add_lanexy_to_graph(self.G, self.lanes)
5758
self.set_intersection_idx()
5859

5960
if self.do_combine_intersections:
@@ -68,6 +69,7 @@ def init_intersections(self):
6869
self.set_lane_intersection_relation()
6970
self.check_if_all_lanes_are_on_segment()
7071
self.update_segment_ids()
72+
self.create_lane_segment_dict()
7173
self.update_road_ids()
7274

7375
# from pathlib import Path
@@ -78,12 +80,10 @@ def update_road_ids(self):
7880
"""
7981
Updates the road_ids of the lane to the segment ID
8082
"""
81-
for i, lane in enumerate(self.locator.all_lanes):
83+
for i, lane in enumerate(self.lanes.values()):
8284
lane_id = lane.idx.lane_id
8385
if lane_id in self.lane_segment_dict and self.lane_segment_dict[lane_id].segment is not None:
84-
self.locator.all_lanes[i].idx = OsiLaneId(
85-
road_id=self.lane_segment_dict[lane_id].segment_idx, lane_id=lane_id
86-
)
86+
lane.idx = OsiLaneId(road_id=self.lane_segment_dict[lane_id].segment.idx, lane_id=lane_id)
8787
else:
8888
# print(f"Lane {lane_id} is not part of an intersection, so it has no segment ID")
8989
pass
@@ -100,18 +100,18 @@ def create_parallel_lane_dict(self):
100100
Returns:
101101
dict: A dictionary mapping each lane's lane_id to a list of parallel lane ids.
102102
"""
103-
lane_dict = {lane.idx.lane_id: [] for lane in self.locator.all_lanes}
103+
lane_dict = {lane.idx.lane_id: [] for lane in self.lanes.values()}
104104

105105
# Precompute lane directions for faster comparisons
106106
lane_directions = {}
107-
for lane in self.locator.all_lanes:
107+
for lane in self.lanes.values():
108108
coords = np.array(lane.centerline.coords)
109109
direction = coords[-1] - coords[0]
110110
lane_directions[lane.idx.lane_id] = direction / np.linalg.norm(direction)
111111

112112
# Use spatial index to find potential parallel lanes
113-
lane_centerlines = [lane.centerline for lane in self.locator.all_lanes]
114-
lane_ids = [lane.idx.lane_id for lane in self.locator.all_lanes]
113+
lane_centerlines = [lane.centerline for lane in self.lanes.values()]
114+
lane_ids = [lane.idx.lane_id for lane in self.lanes.values()]
115115

116116
# Only build tree if there are lanes
117117
if lane_centerlines:
@@ -120,7 +120,7 @@ def create_parallel_lane_dict(self):
120120
tree = STRtree(buffered_centerlines)
121121

122122
# For each lane, find potential parallel lanes
123-
for i, lane in enumerate(self.locator.all_lanes):
123+
for i, lane in enumerate(self.lanes.values()):
124124
lane_id = lane.idx.lane_id
125125
buffer_geom = buffered_centerlines[i]
126126

@@ -155,7 +155,7 @@ def check_if_all_lanes_are_on_segment(self):
155155
Returns:
156156
bool: True if all lanes are on a segment, False otherwise.
157157
"""
158-
for lane in self.locator.all_lanes:
158+
for lane in self.lanes.values():
159159
lane_id = lane.idx.lane_id
160160
if lane_id not in self.lane_segment_dict or self.lane_segment_dict[lane_id].segment is None:
161161
print(f"Lane {lane_id} is not on a segment")
@@ -192,9 +192,9 @@ def trajectory_segment_detection(self, trajectory):
192192
buffered_polygons[segment.idx] = segment.polygon.buffer(buffer)
193193

194194
for i, (frame, x, y, _, _) in enumerate(trajectory):
195-
point = Point(x, y)
195+
point = Point(x, y).buffer(2)
196196
for segment_id, buffered_poly in buffered_polygons.items():
197-
if buffered_poly.contains(point):
197+
if buffered_poly.intersects(point):
198198
trajectory[i, 4] = segment_id
199199
break
200200

@@ -221,7 +221,7 @@ def create_lane_dict(self):
221221
Returns:
222222
lane_dict (dict): A dictionary mapping each lane's lane_id to the lane object.
223223
"""
224-
self.lane_dict = {lane.idx.lane_id: lane for lane in self.locator.all_lanes}
224+
self.lane_dict = {lane.idx.lane_id: lane for lane in self.lanes.values()}
225225
return self.lane_dict
226226

227227
def get_lane_successors_and_predecessors(self):
@@ -240,7 +240,7 @@ def get_lane_successors_and_predecessors(self):
240240
lane_successors = {}
241241
lane_predecessors = {}
242242

243-
for lane in self.locator.all_lanes:
243+
for lane in self.lanes.values():
244244
# Use lane_id as keys, convert successor_ids to lane_ids if needed
245245
lane_successors[lane.idx.lane_id] = [
246246
succ_id.lane_id if hasattr(succ_id, "lane_id") else succ_id for succ_id in lane.successor_ids
@@ -267,7 +267,7 @@ def get_intersecting_lanes(self, buffer: float = None):
267267
buffer = self.lane_buffer
268268

269269
# Precompute buffered geometries
270-
buffered_lanes = {lane.idx.lane_id: lane.centerline.buffer(buffer) for lane in self.lanes}
270+
buffered_lanes = {lane.idx.lane_id: lane.centerline.buffer(buffer) for lane in self.lanes.values()}
271271

272272
# Create spatial index
273273
buffered_geoms = list(buffered_lanes.values())
@@ -278,7 +278,7 @@ def get_intersecting_lanes(self, buffer: float = None):
278278
tree = None
279279

280280
intersecting_lanes = {}
281-
for i, lane in enumerate(self.lanes):
281+
for i, lane in enumerate(self.lanes.values()):
282282
lane_id = lane.idx.lane_id
283283
if tree is None:
284284
intersecting_lanes[lane_id] = []
@@ -454,7 +454,7 @@ def add_non_intersecting_lanes_to_intersection(self):
454454
"""
455455
for intersection in self.intersections:
456456
intersection.update_polygon()
457-
for lane in self.locator.all_lanes:
457+
for lane in self.lanes.values():
458458
lane_id = lane.idx.lane_id
459459
if (
460460
lane_id not in intersection.lane_ids
@@ -470,23 +470,25 @@ def create_lane_segment_dict(self):
470470
segment_name = nt("SegmentName", ["lane_id", "segment_idx", "segment"])
471471
# Combine self.intersections and self.isolated_connections into the lane_segment_dict
472472
segent_list = self.intersections + self.isolated_connections
473+
lane_segment_dict = {lane_id: segment_name(lane_id, None, None) for lane_id in self.lane_dict.keys()}
473474

474475
for segment in segent_list:
475476
for lane in segment.lanes:
476477
lane_id = lane.idx.lane_id
477478
try:
478-
if self.lane_segment_dict[lane_id].segment is None:
479+
if lane_segment_dict[lane_id].segment is None:
479480
# The lane is not connected to any segment yet, so we can add it
480-
self.lane_segment_dict[lane_id] = segment_name(lane_id, segment.idx, segment)
481+
lane_segment_dict[lane_id] = segment_name(lane_id, segment.idx, segment)
481482
else:
482483
# If the lane is already in the dictionary, check if the segment is the same
483-
if self.lane_segment_dict[lane_id].segment_idx != segment.idx:
484+
if lane_segment_dict[lane_id].segment_idx != segment.idx:
484485
raise ValueError(
485-
f"Lane {lane_id} is already in segment {self.lane_segment_dict[lane_id].segment_idx} but trying to add it to segment {segment.idx}"
486+
f"Lane {lane_id} is already in segment {lane_segment_dict[lane_id].segment_idx} but trying to add it to segment {segment.idx}"
486487
)
487488
except ValueError as e:
488489
print(f"Error: {e}")
489490
continue
491+
self.lane_segment_dict = lane_segment_dict
490492

491493
def create_non_intersecting_lane_graph(self):
492494
"""Create a graph with each lane which is not part of a intersection as a node and the edges are the successors and predecessors of the lanes.
@@ -496,7 +498,7 @@ def create_non_intersecting_lane_graph(self):
496498
G (networkx.Graph): A graph with each lane as a node and the edges are the successors and predecessors of the lanes.
497499
"""
498500
G = nx.Graph()
499-
for lane in self.locator.all_lanes:
501+
for lane in self.lanes.values():
500502
lane_id = lane.idx.lane_id
501503
if lane_id not in self.lane_segment_dict or self.lane_segment_dict[lane_id].segment is None:
502504
G.add_node(lane_id)
@@ -572,10 +574,16 @@ def combine_isolated_connections(self, isolated_connections):
572574
if (
573575
i != j
574576
and connection.intersection_idxs == connection2.intersection_idxs
575-
and len(connection.intersection_idxs) > 0
577+
and len(connection.intersection_idxs) > 1
578+
):
579+
combined_connections.append([i, j])
580+
elif (
581+
i != j
582+
and connection.intersection_idxs == connection2.intersection_idxs
583+
and len(connection.intersection_idxs) == 1
584+
and connection.polygon.distance(connection2.polygon) < 5
576585
):
577586
combined_connections.append([i, j])
578-
# print(f"Connection {i} and {j} can be combined")
579587
final_combined = self.find_resulting_intersections(combined_connections)
580588
# print(f"Final combined: {final_combined}")
581589
new_connections = []
@@ -606,21 +614,21 @@ def set_lane_intersection_relation(self):
606614
Sets the attribute lane.approching true if the lane is connecting to an intersection.
607615
Setss the attribute lane.on_intersection true if the lane is part of an intersection.
608616
"""
609-
for lane in self.locator.all_lanes:
610-
lane.on_intersection = False
611-
lane.is_approaching = False
617+
for lane in self.lanes.values():
618+
self.lanes[lane.idx].on_intersection = False
619+
self.lanes[lane.idx].is_approaching = False
612620

613621
for intersection in self.intersections:
614622
for lane in intersection.lanes:
615-
lane.on_intersection = True
616-
lane.is_approaching = False
623+
self.lanes[lane.idx].on_intersection = True
624+
self.lanes[lane.idx].is_approaching = False
617625

618626
for predecessor_id in lane.predecessor_ids:
619627
# Convert to lane_id if it's a namedtuple
620628
pred_id = predecessor_id.lane_id if hasattr(predecessor_id, "lane_id") else predecessor_id
621629
predecessor = self.lane_dict[pred_id]
622630
if predecessor.on_intersection is not True:
623-
predecessor.is_approaching = True
631+
self.lanes[predecessor.idx].is_approaching = True
624632

625633
def plot(self, output_plot=None, trajectory=None, plot_lane_ids=False):
626634
"""
@@ -634,7 +642,7 @@ def plot(self, output_plot=None, trajectory=None, plot_lane_ids=False):
634642
fig, ax = plt.subplots(1, 1)
635643
ax.set_aspect(1)
636644

637-
for lane in self.locator.all_lanes:
645+
for lane in self.lanes.values():
638646
c = "black" if not lane.type == betterosi.LaneClassificationType.TYPE_INTERSECTION else "green"
639647
ax.plot(*lane.centerline.xy, color=c, alpha=0.3, zorder=-10)
640648
if plot_lane_ids:

omega_prime/recording.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import altair as alt
2424
import polars as pl
2525
import polars_st as st
26-
from .locator import Locator
2726
from .mapsemgents import Mapsegmentation
2827

2928
pi_valued = pa.Check.between(-np.pi, np.pi)
@@ -617,7 +616,7 @@ def from_file(
617616
lanes_new[new_lane.idx.lane_id] = new_lane
618617
idx_count += 1
619618

620-
# Update references in other lanes' predecessors/successors
619+
# Update references in other lanes' predecessors/successors
621620
for other_lane in lanes_or.values():
622621
if lane.idx in other_lane.successor_ids:
623622
# Replace reference to original lane with first segment
@@ -629,6 +628,46 @@ def from_file(
629628
other_lane.predecessor_ids[idx] = segment_lanes[-1].idx
630629

631630
# Replace original lanes with segmented lanes
631+
632+
# Do a check for the presessor and successor: Check if the distance between the centerlines is greater than the max_len --> if yes, then remove the connection
633+
for lane in lanes_new.values():
634+
if lane.predecessor_ids:
635+
for pre in lane.predecessor_ids:
636+
pre_to_remove = []
637+
if lanes_new[pre.lane_id].centerline.distance(lane.centerline) > max_len:
638+
print(
639+
f"Removing connection between {pre} and {lane.idx} due to distance > {max_len}"
640+
)
641+
pre_to_remove.append(pre)
642+
try:
643+
lanes_new[pre.lane_id].successor_ids.remove(lane.idx)
644+
except ValueError:
645+
pass # If the successor is not in the list, ignore
646+
647+
for pre in pre_to_remove:
648+
try:
649+
lanes_new[lane.idx.lane_id].predecessor_ids.remove(pre)
650+
except ValueError:
651+
pass # If the predecessor is not in the list, ignore
652+
if lane.successor_ids:
653+
for suc in lane.successor_ids:
654+
suc_to_remove = []
655+
if lanes_new[suc.lane_id].centerline.distance(lane.centerline) > max_len:
656+
print(
657+
f"Removing connection between {lane.idx} and {suc} due to distance > {max_len}"
658+
)
659+
suc_to_remove.append(suc)
660+
try:
661+
lanes_new[suc.lane_id].predecessor_ids.remove(lane.idx)
662+
except ValueError:
663+
pass # If the predecessor is not in the list, ignore
664+
665+
for suc in suc_to_remove:
666+
try:
667+
lanes_new[lane.idx.lane_id].successor_ids.remove(suc)
668+
except ValueError:
669+
pass
670+
632671
r.map.lanes = {lane.idx: lane for lane in lanes_new.values()}
633672
for lane in r.map.lanes.values():
634673
lane._map = r.map
@@ -895,6 +934,6 @@ def plot_altair(
895934
return view.add_params(op_var)
896935

897936
def create_mapsegments(self):
898-
locator = Locator.from_map(self.map)
899-
self.mapsegment = Mapsegmentation(self.map, locator=locator)
937+
self.mapsegment = Mapsegmentation(self)
900938
self.mapsegment.init_intersections()
939+
print("test")

0 commit comments

Comments
 (0)