88from shapely .strtree import STRtree
99from shapely .geometry import Point
1010from collections import namedtuple as nt
11+ from .locator import Locator
1112
1213concave_hull_ratio = 0.3
1314
@@ -17,10 +18,10 @@ class Mapsegmentation:
1718 A class that handels multiple intersections on a single map.
1819 """
1920
20- def __init__ (self , map , locator , lane_buffer = None , intersection_overlap_buffer = None ):
21- self .map = map
22- self .locator = locator
23- self .lanes = locator . all_lanes
21+ def __init__ (self , recording , lane_buffer = None , intersection_overlap_buffer = None ):
22+ self .map = recording . map
23+ self .locator = Locator . from_map ( recording . map )
24+ self .lanes = recording . map . lanes
2425 self .intersections = []
2526 self .lane_dict = {}
2627 self .lane_successors_dict = {}
@@ -37,7 +38,7 @@ def __init__(self, map, locator, lane_buffer=None, intersection_overlap_buffer=N
3738 self .segments = []
3839
3940 segment_name = nt ("SegmentName" , ["lane_id" , "segment_idx" , "segment" ])
40- for lane in self .locator . all_lanes :
41+ for lane in self .lanes . values () :
4142 self .lane_segment_dict [lane .idx .lane_id ] = segment_name (lane .idx .lane_id , None , None )
4243
4344 def init_intersections (self ):
@@ -53,7 +54,7 @@ def init_intersections(self):
5354 self .parallel_lane_dict = self .create_parallel_lane_dict ()
5455 self .get_intersecting_lanes ()
5556 self .graph_intersection_detection ()
56- self .G = add_lanexy_to_graph (self .G , self .locator )
57+ self .G = add_lanexy_to_graph (self .G , self .lanes )
5758 self .set_intersection_idx ()
5859
5960 if self .do_combine_intersections :
@@ -68,6 +69,7 @@ def init_intersections(self):
6869 self .set_lane_intersection_relation ()
6970 self .check_if_all_lanes_are_on_segment ()
7071 self .update_segment_ids ()
72+ self .create_lane_segment_dict ()
7173 self .update_road_ids ()
7274
7375 # from pathlib import Path
@@ -78,12 +80,10 @@ def update_road_ids(self):
7880 """
7981 Updates the road_ids of the lane to the segment ID
8082 """
81- for i , lane in enumerate (self .locator . all_lanes ):
83+ for i , lane in enumerate (self .lanes . values () ):
8284 lane_id = lane .idx .lane_id
8385 if lane_id in self .lane_segment_dict and self .lane_segment_dict [lane_id ].segment is not None :
84- self .locator .all_lanes [i ].idx = OsiLaneId (
85- road_id = self .lane_segment_dict [lane_id ].segment_idx , lane_id = lane_id
86- )
86+ lane .idx = OsiLaneId (road_id = self .lane_segment_dict [lane_id ].segment .idx , lane_id = lane_id )
8787 else :
8888 # print(f"Lane {lane_id} is not part of an intersection, so it has no segment ID")
8989 pass
@@ -100,18 +100,18 @@ def create_parallel_lane_dict(self):
100100 Returns:
101101 dict: A dictionary mapping each lane's lane_id to a list of parallel lane ids.
102102 """
103- lane_dict = {lane .idx .lane_id : [] for lane in self .locator . all_lanes }
103+ lane_dict = {lane .idx .lane_id : [] for lane in self .lanes . values () }
104104
105105 # Precompute lane directions for faster comparisons
106106 lane_directions = {}
107- for lane in self .locator . all_lanes :
107+ for lane in self .lanes . values () :
108108 coords = np .array (lane .centerline .coords )
109109 direction = coords [- 1 ] - coords [0 ]
110110 lane_directions [lane .idx .lane_id ] = direction / np .linalg .norm (direction )
111111
112112 # Use spatial index to find potential parallel lanes
113- lane_centerlines = [lane .centerline for lane in self .locator . all_lanes ]
114- lane_ids = [lane .idx .lane_id for lane in self .locator . all_lanes ]
113+ lane_centerlines = [lane .centerline for lane in self .lanes . values () ]
114+ lane_ids = [lane .idx .lane_id for lane in self .lanes . values () ]
115115
116116 # Only build tree if there are lanes
117117 if lane_centerlines :
@@ -120,7 +120,7 @@ def create_parallel_lane_dict(self):
120120 tree = STRtree (buffered_centerlines )
121121
122122 # For each lane, find potential parallel lanes
123- for i , lane in enumerate (self .locator . all_lanes ):
123+ for i , lane in enumerate (self .lanes . values () ):
124124 lane_id = lane .idx .lane_id
125125 buffer_geom = buffered_centerlines [i ]
126126
@@ -155,7 +155,7 @@ def check_if_all_lanes_are_on_segment(self):
155155 Returns:
156156 bool: True if all lanes are on a segment, False otherwise.
157157 """
158- for lane in self .locator . all_lanes :
158+ for lane in self .lanes . values () :
159159 lane_id = lane .idx .lane_id
160160 if lane_id not in self .lane_segment_dict or self .lane_segment_dict [lane_id ].segment is None :
161161 print (f"Lane { lane_id } is not on a segment" )
@@ -192,9 +192,9 @@ def trajectory_segment_detection(self, trajectory):
192192 buffered_polygons [segment .idx ] = segment .polygon .buffer (buffer )
193193
194194 for i , (frame , x , y , _ , _ ) in enumerate (trajectory ):
195- point = Point (x , y )
195+ point = Point (x , y ). buffer ( 2 )
196196 for segment_id , buffered_poly in buffered_polygons .items ():
197- if buffered_poly .contains (point ):
197+ if buffered_poly .intersects (point ):
198198 trajectory [i , 4 ] = segment_id
199199 break
200200
@@ -221,7 +221,7 @@ def create_lane_dict(self):
221221 Returns:
222222 lane_dict (dict): A dictionary mapping each lane's lane_id to the lane object.
223223 """
224- self .lane_dict = {lane .idx .lane_id : lane for lane in self .locator . all_lanes }
224+ self .lane_dict = {lane .idx .lane_id : lane for lane in self .lanes . values () }
225225 return self .lane_dict
226226
227227 def get_lane_successors_and_predecessors (self ):
@@ -240,7 +240,7 @@ def get_lane_successors_and_predecessors(self):
240240 lane_successors = {}
241241 lane_predecessors = {}
242242
243- for lane in self .locator . all_lanes :
243+ for lane in self .lanes . values () :
244244 # Use lane_id as keys, convert successor_ids to lane_ids if needed
245245 lane_successors [lane .idx .lane_id ] = [
246246 succ_id .lane_id if hasattr (succ_id , "lane_id" ) else succ_id for succ_id in lane .successor_ids
@@ -267,7 +267,7 @@ def get_intersecting_lanes(self, buffer: float = None):
267267 buffer = self .lane_buffer
268268
269269 # Precompute buffered geometries
270- buffered_lanes = {lane .idx .lane_id : lane .centerline .buffer (buffer ) for lane in self .lanes }
270+ buffered_lanes = {lane .idx .lane_id : lane .centerline .buffer (buffer ) for lane in self .lanes . values () }
271271
272272 # Create spatial index
273273 buffered_geoms = list (buffered_lanes .values ())
@@ -278,7 +278,7 @@ def get_intersecting_lanes(self, buffer: float = None):
278278 tree = None
279279
280280 intersecting_lanes = {}
281- for i , lane in enumerate (self .lanes ):
281+ for i , lane in enumerate (self .lanes . values () ):
282282 lane_id = lane .idx .lane_id
283283 if tree is None :
284284 intersecting_lanes [lane_id ] = []
@@ -454,7 +454,7 @@ def add_non_intersecting_lanes_to_intersection(self):
454454 """
455455 for intersection in self .intersections :
456456 intersection .update_polygon ()
457- for lane in self .locator . all_lanes :
457+ for lane in self .lanes . values () :
458458 lane_id = lane .idx .lane_id
459459 if (
460460 lane_id not in intersection .lane_ids
@@ -470,23 +470,25 @@ def create_lane_segment_dict(self):
470470 segment_name = nt ("SegmentName" , ["lane_id" , "segment_idx" , "segment" ])
471471 # Combine self.intersections and self.isolated_connections into the lane_segment_dict
472472 segent_list = self .intersections + self .isolated_connections
473+ lane_segment_dict = {lane_id : segment_name (lane_id , None , None ) for lane_id in self .lane_dict .keys ()}
473474
474475 for segment in segent_list :
475476 for lane in segment .lanes :
476477 lane_id = lane .idx .lane_id
477478 try :
478- if self . lane_segment_dict [lane_id ].segment is None :
479+ if lane_segment_dict [lane_id ].segment is None :
479480 # The lane is not connected to any segment yet, so we can add it
480- self . lane_segment_dict [lane_id ] = segment_name (lane_id , segment .idx , segment )
481+ lane_segment_dict [lane_id ] = segment_name (lane_id , segment .idx , segment )
481482 else :
482483 # If the lane is already in the dictionary, check if the segment is the same
483- if self . lane_segment_dict [lane_id ].segment_idx != segment .idx :
484+ if lane_segment_dict [lane_id ].segment_idx != segment .idx :
484485 raise ValueError (
485- f"Lane { lane_id } is already in segment { self . lane_segment_dict [lane_id ].segment_idx } but trying to add it to segment { segment .idx } "
486+ f"Lane { lane_id } is already in segment { lane_segment_dict [lane_id ].segment_idx } but trying to add it to segment { segment .idx } "
486487 )
487488 except ValueError as e :
488489 print (f"Error: { e } " )
489490 continue
491+ self .lane_segment_dict = lane_segment_dict
490492
491493 def create_non_intersecting_lane_graph (self ):
492494 """Create a graph with each lane which is not part of a intersection as a node and the edges are the successors and predecessors of the lanes.
@@ -496,7 +498,7 @@ def create_non_intersecting_lane_graph(self):
496498 G (networkx.Graph): A graph with each lane as a node and the edges are the successors and predecessors of the lanes.
497499 """
498500 G = nx .Graph ()
499- for lane in self .locator . all_lanes :
501+ for lane in self .lanes . values () :
500502 lane_id = lane .idx .lane_id
501503 if lane_id not in self .lane_segment_dict or self .lane_segment_dict [lane_id ].segment is None :
502504 G .add_node (lane_id )
@@ -572,10 +574,16 @@ def combine_isolated_connections(self, isolated_connections):
572574 if (
573575 i != j
574576 and connection .intersection_idxs == connection2 .intersection_idxs
575- and len (connection .intersection_idxs ) > 0
577+ and len (connection .intersection_idxs ) > 1
578+ ):
579+ combined_connections .append ([i , j ])
580+ elif (
581+ i != j
582+ and connection .intersection_idxs == connection2 .intersection_idxs
583+ and len (connection .intersection_idxs ) == 1
584+ and connection .polygon .distance (connection2 .polygon ) < 5
576585 ):
577586 combined_connections .append ([i , j ])
578- # print(f"Connection {i} and {j} can be combined")
579587 final_combined = self .find_resulting_intersections (combined_connections )
580588 # print(f"Final combined: {final_combined}")
581589 new_connections = []
@@ -606,21 +614,21 @@ def set_lane_intersection_relation(self):
606614 Sets the attribute lane.approching true if the lane is connecting to an intersection.
607615 Setss the attribute lane.on_intersection true if the lane is part of an intersection.
608616 """
609- for lane in self .locator . all_lanes :
610- lane .on_intersection = False
611- lane .is_approaching = False
617+ for lane in self .lanes . values () :
618+ self . lanes [ lane . idx ] .on_intersection = False
619+ self . lanes [ lane . idx ] .is_approaching = False
612620
613621 for intersection in self .intersections :
614622 for lane in intersection .lanes :
615- lane .on_intersection = True
616- lane .is_approaching = False
623+ self . lanes [ lane . idx ] .on_intersection = True
624+ self . lanes [ lane . idx ] .is_approaching = False
617625
618626 for predecessor_id in lane .predecessor_ids :
619627 # Convert to lane_id if it's a namedtuple
620628 pred_id = predecessor_id .lane_id if hasattr (predecessor_id , "lane_id" ) else predecessor_id
621629 predecessor = self .lane_dict [pred_id ]
622630 if predecessor .on_intersection is not True :
623- predecessor .is_approaching = True
631+ self . lanes [ predecessor . idx ] .is_approaching = True
624632
625633 def plot (self , output_plot = None , trajectory = None , plot_lane_ids = False ):
626634 """
@@ -634,7 +642,7 @@ def plot(self, output_plot=None, trajectory=None, plot_lane_ids=False):
634642 fig , ax = plt .subplots (1 , 1 )
635643 ax .set_aspect (1 )
636644
637- for lane in self .locator . all_lanes :
645+ for lane in self .lanes . values () :
638646 c = "black" if not lane .type == betterosi .LaneClassificationType .TYPE_INTERSECTION else "green"
639647 ax .plot (* lane .centerline .xy , color = c , alpha = 0.3 , zorder = - 10 )
640648 if plot_lane_ids :
0 commit comments