@@ -355,6 +355,7 @@ def __init__(
355355 host_vehicle_idx : int | None = None ,
356356 validate = False ,
357357 compute_polygons = False ,
358+ traffic_light_states : dict | None = None ,
358359 ):
359360 if not isinstance (df , pl .DataFrame ):
360361 df = pl .DataFrame (df , schema_overrides = polars_schema )
@@ -384,6 +385,8 @@ def __init__(
384385 .alias ("acc" )
385386 )
386387 self .projections = projections if projections is not None else []
388+ self .traffic_light_states = traffic_light_states if traffic_light_states is not None else {}
389+
387390 self ._df = df
388391 self .map = map
389392 self ._moving_objects = None
@@ -438,11 +441,14 @@ def to_osi_gts(self) -> list[betterosi.GroundTruth]:
438441 if self .map is not None and isinstance (self .map , MapOsi | MapOsiCenterline ):
439442 gt .lane_boundary = [b ._osi for b in self .map .lane_boundaries .values ()]
440443 gt .lane = [l ._osi for l in self .map .lanes .values ()]
444+ if nanos in self .traffic_light_states :
445+ gt .traffic_light = self .traffic_light_states [nanos ]
441446 yield gt
442447
443448 @classmethod
444449 def from_osi_gts (cls , gts : list [betterosi .GroundTruth ], ** kwargs ):
445450 projs = []
451+ traffic_light_states = {}
446452
447453 gts , tmp_gts = itertools .tee (gts , 2 )
448454 first_gt = next (tmp_gts )
@@ -471,6 +477,9 @@ def get_gts():
471477 else None ,
472478 )
473479 )
480+
481+ traffic_light_states [total_nanos ] = gt .traffic_light
482+
474483 for mv in gt .moving_object :
475484 yield dict (
476485 total_nanos = total_nanos ,
@@ -500,7 +509,7 @@ def get_gts():
500509 )
501510
502511 df_mv = pl .DataFrame (get_gts (), schema = polars_schema ).sort (["total_nanos" , "idx" ])
503- return cls (df_mv , projections = projs , host_vehicle_idx = host_vehicle_idx , ** kwargs )
512+ return cls (df_mv , projections = projs , host_vehicle_idx = host_vehicle_idx , traffic_light_states = traffic_light_states , ** kwargs )
504513
505514 @classmethod
506515 def from_file (
0 commit comments