Skip to content

Commit 69ff15d

Browse files
Mimoun el GhaoutyMichaelSchuldes
authored andcommitted
traffic_light_states to Recording: init, to_osi_gts and from_osi_gts
1 parent b3ac63d commit 69ff15d

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

omega_prime/recording.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def __init__(
355355
host_vehicle_idx: int | None = None,
356356
validate=False,
357357
compute_polygons=False,
358+
traffic_light_states: dict | None = None,
358359
):
359360
if not isinstance(df, pl.DataFrame):
360361
df = pl.DataFrame(df, schema_overrides=polars_schema)
@@ -384,6 +385,8 @@ def __init__(
384385
.alias("acc")
385386
)
386387
self.projections = projections if projections is not None else []
388+
self.traffic_light_states = traffic_light_states if traffic_light_states is not None else {}
389+
387390
self._df = df
388391
self.map = map
389392
self._moving_objects = None
@@ -438,11 +441,14 @@ def to_osi_gts(self) -> list[betterosi.GroundTruth]:
438441
if self.map is not None and isinstance(self.map, MapOsi | MapOsiCenterline):
439442
gt.lane_boundary = [b._osi for b in self.map.lane_boundaries.values()]
440443
gt.lane = [l._osi for l in self.map.lanes.values()]
444+
if nanos in self.traffic_light_states:
445+
gt.traffic_light = self.traffic_light_states[nanos]
441446
yield gt
442447

443448
@classmethod
444449
def from_osi_gts(cls, gts: list[betterosi.GroundTruth], **kwargs):
445450
projs = []
451+
traffic_light_states = {}
446452

447453
gts, tmp_gts = itertools.tee(gts, 2)
448454
first_gt = next(tmp_gts)
@@ -471,6 +477,9 @@ def get_gts():
471477
else None,
472478
)
473479
)
480+
481+
traffic_light_states[total_nanos] = gt.traffic_light
482+
474483
for mv in gt.moving_object:
475484
yield dict(
476485
total_nanos=total_nanos,
@@ -500,7 +509,7 @@ def get_gts():
500509
)
501510

502511
df_mv = pl.DataFrame(get_gts(), schema=polars_schema).sort(["total_nanos", "idx"])
503-
return cls(df_mv, projections=projs, host_vehicle_idx=host_vehicle_idx, **kwargs)
512+
return cls(df_mv, projections=projs, host_vehicle_idx=host_vehicle_idx, traffic_light_states=traffic_light_states, **kwargs)
504513

505514
@classmethod
506515
def from_file(

0 commit comments

Comments
 (0)