Skip to content

Commit e19688d

Browse files
add interactive plotting with altair
1 parent 7b60fa2 commit e19688d

File tree

4 files changed

+304
-116
lines changed

4 files changed

+304
-116
lines changed

omega_prime/map_odr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,17 @@ def from_file(
130130
return cls.create(
131131
odr_xml=map.open_drive_xml_content, name=map.map_reference, step_size=step_size, parse=parse
132132
)
133+
133134
@property
134135
def lanes(self):
135136
if self._lanes is None:
136137
self.parse()
137138
return self._lanes
138-
139+
139140
@lanes.setter
140141
def lanes(self, val):
141142
self._lanes = val
142-
143+
143144
@property
144145
def lane_boundaries(self):
145146
if self._lane_boundaries is None:
@@ -149,7 +150,7 @@ def lane_boundaries(self):
149150
@lane_boundaries.setter
150151
def lane_boundaries(self, val):
151152
self._lane_boundaries = val
152-
153+
153154
@classmethod
154155
def create(cls, odr_xml, name, step_size=0.01, parse: bool = False):
155156
self = cls(odr_xml=odr_xml, name=name, step_size=step_size, lanes={}, lane_boundaries={})

omega_prime/recording.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from matplotlib.patches import Polygon as PltPolygon
1010
import pandas as pd
1111
import polars as pl
12+
import altair as alt
1213

1314
import pandera as pa
1415
import pandera.extensions as extensions
@@ -19,6 +20,9 @@
1920
from .map import MapOsi, ProjectionOffset, MapOsiCenterline
2021
import itertools
2122
from functools import partial
23+
import altair as alt
24+
import polars as pl
25+
import polars_st as st
2226

2327
pi_valued = pa.Check.between(-np.pi, np.pi)
2428

@@ -228,7 +232,7 @@ class Recording:
228232

229233
@staticmethod
230234
def _add_polygons(df):
231-
if 'polygon' not in df.columns:
235+
if "polygon" not in df.columns:
232236
ar = (
233237
df[:]
234238
.select(
@@ -586,3 +590,84 @@ def to_parquet(self, filename):
586590
else:
587591
t = t.cast(t.schema.with_metadata(proj_dict | {}))
588592
pq.write_table(t, filename)
593+
594+
def plot_altair(self, start_frame=0, end_frame=-1, plot_map=True, metric_column=None, idx=None):
595+
if "polygon" not in self._df.columns:
596+
self._df = self._add_polygons(self._df)
597+
if "geometry" not in self._df.columns:
598+
self._df = self._df.with_columns(geometry=st.from_shapely("polygon"))
599+
if not hasattr(self, "_map_df") and plot_map:
600+
self._map_df = pl.DataFrame(
601+
[
602+
pl.Series(name="polygon", values=[l.polygon for l in self.map.lanes.values()]),
603+
pl.Series(name="idx", values=[i for i, _ in enumerate(self.map.lanes.keys())]),
604+
pl.Series(name="type", values=[o.type.name for o in self.map.lanes.values()]),
605+
]
606+
)
607+
self._map_df = self._map_df.with_columns(geometry=st.from_shapely("polygon"))
608+
609+
if end_frame != -1:
610+
df = self._df.filter(pl.col("frame") < end_frame, pl.col("frame") >= start_frame)
611+
else:
612+
df = self._df.filter(pl.col("frame") >= start_frame)
613+
614+
[frame_min], [frame_max] = df.select(
615+
pl.col("frame").min().alias("min"),
616+
pl.col("frame").max().alias("max"),
617+
)[0]
618+
slider = alt.binding_range(min=frame_min, max=frame_max, step=1, name="frame")
619+
op_var = alt.param(value=0, bind=slider)
620+
621+
[xmin], [xmax], [ymin], [ymax] = df.select(
622+
pl.col("x").min().alias("xmin"),
623+
pl.col("x").max().alias("xmax"),
624+
pl.col("y").min().alias("ymin"),
625+
pl.col("y").max().alias("ymax"),
626+
)[0]
627+
pov = {
628+
"type": "Feature",
629+
"geometry": {
630+
"type": "Polygon",
631+
"coordinates": [[[xmax, ymax], [xmax, ymin], [xmin, ymin], [xmin, ymax], [xmax, ymax]]],
632+
},
633+
"properties": {},
634+
}
635+
map = (
636+
self._map_df["geometry", "idx", "type"]
637+
.st.plot(color="green", fillOpacity=0.4)
638+
.encode(tooltip=["properties.idx:N", "properties.type:O"])
639+
)
640+
mvs = (
641+
df["geometry", "idx", "frame", "type"]
642+
.st.plot()
643+
.encode(
644+
tooltip=["properties.idx:N", "properties.frame:N", "properties.type:O"],
645+
color=alt.value("blue")
646+
if idx is None
647+
else alt.when(alt.FieldEqualPredicate(equal=idx, field="properties.idx"))
648+
.then(alt.value("red"))
649+
.otherwise(alt.value("blue")),
650+
)
651+
.transform_filter(alt.FieldEqualPredicate(field="properties.frame", equal=op_var))
652+
)
653+
654+
map_view = (
655+
(map + mvs)
656+
.project("identity", reflectY=True, fit=pov)
657+
.properties(height=int(ymax - ymin) * 3, width=int(xmax - xmin) * 3, title="Map")
658+
)
659+
view = map_view
660+
if metric_column is not None and idx is not None:
661+
metric = (
662+
df["idx", metric_column, "frame"]
663+
.filter(idx=idx)
664+
.plot.line(x="frame", y=metric_column, color=alt.value("red"))
665+
.properties(title=f"{metric_column} of object {idx}")
666+
)
667+
vertline = (
668+
alt.Chart()
669+
.mark_rule()
670+
.encode(x=alt.datum(op_var, type="quantitative", scale=alt.Scale(domain=[frame_min, frame_max])))
671+
)
672+
view = view | (metric + vertline)
673+
return view.add_params(op_var)

tutorial.ipynb

Lines changed: 187 additions & 85 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)