Skip to content

Commit dcefbff

Browse files
ruff format
1 parent a6ca82c commit dcefbff

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

omega_prime/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ def visualize(
102102
height: int = 500,
103103
width: int = 1600,
104104
start_frame: int = 0,
105-
end_frame: int = -1
105+
end_frame: int = -1,
106106
):
107107
import altair as alt
108+
108109
alt.renderers.enable("browser")
109110
r = omega_prime.Recording.from_file(input, validate=False, parse_map=True)
110-
r.plot_altair(start_frame=start_frame, end_frame=end_frame,height=height, width=width).show()
111+
r.plot_altair(start_frame=start_frame, end_frame=end_frame, height=height, width=width).show()
111112

112113

113114
def main():

omega_prime/recording.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def __init__(
386386
)
387387
self.projections = projections if projections is not None else []
388388
self.traffic_light_states = traffic_light_states if traffic_light_states is not None else {}
389-
389+
390390
self._df = df
391391
self.map = map
392392
self._moving_objects = None
@@ -479,7 +479,7 @@ def get_gts():
479479
)
480480

481481
traffic_light_states[total_nanos] = gt.traffic_light
482-
482+
483483
for mv in gt.moving_object:
484484
yield dict(
485485
total_nanos=total_nanos,
@@ -509,7 +509,13 @@ def get_gts():
509509
)
510510

511511
df_mv = pl.DataFrame(get_gts(), schema=polars_schema).sort(["total_nanos", "idx"])
512-
return cls(df_mv, projections=projs, host_vehicle_idx=host_vehicle_idx, traffic_light_states=traffic_light_states, **kwargs)
512+
return cls(
513+
df_mv,
514+
projections=projs,
515+
host_vehicle_idx=host_vehicle_idx,
516+
traffic_light_states=traffic_light_states,
517+
**kwargs,
518+
)
513519

514520
@classmethod
515521
def from_file(
@@ -702,7 +708,15 @@ def to_parquet(self, filename):
702708
pq.write_table(t, filename)
703709

704710
def plot_altair(
705-
self, start_frame=0, end_frame=-1, plot_map=True, plot_map_polys=True, metric_column=None, idx=None, height=None,width=None
711+
self,
712+
start_frame=0,
713+
end_frame=-1,
714+
plot_map=True,
715+
plot_map_polys=True,
716+
metric_column=None,
717+
idx=None,
718+
height=None,
719+
width=None,
706720
):
707721
if "polygon" not in self._df.columns:
708722
self._df = self._add_polygons(self._df)

0 commit comments

Comments
 (0)