|
9 | 9 | from matplotlib.patches import Polygon as PltPolygon |
10 | 10 | import pandas as pd |
11 | 11 | import polars as pl |
| 12 | +import altair as alt |
12 | 13 |
|
13 | 14 | import pandera as pa |
14 | 15 | import pandera.extensions as extensions |
|
19 | 20 | from .map import MapOsi, ProjectionOffset, MapOsiCenterline |
20 | 21 | import itertools |
21 | 22 | from functools import partial |
| 23 | +import altair as alt |
| 24 | +import polars as pl |
| 25 | +import polars_st as st |
22 | 26 |
|
23 | 27 | pi_valued = pa.Check.between(-np.pi, np.pi) |
24 | 28 |
|
@@ -228,7 +232,7 @@ class Recording: |
228 | 232 |
|
229 | 233 | @staticmethod |
230 | 234 | def _add_polygons(df): |
231 | | - if 'polygon' not in df.columns: |
| 235 | + if "polygon" not in df.columns: |
232 | 236 | ar = ( |
233 | 237 | df[:] |
234 | 238 | .select( |
@@ -586,3 +590,84 @@ def to_parquet(self, filename): |
586 | 590 | else: |
587 | 591 | t = t.cast(t.schema.with_metadata(proj_dict | {})) |
588 | 592 | pq.write_table(t, filename) |
| 593 | + |
| 594 | + def plot_altair(self, start_frame=0, end_frame=-1, plot_map=True, metric_column=None, idx=None): |
| 595 | + if "polygon" not in self._df.columns: |
| 596 | + self._df = self._add_polygons(self._df) |
| 597 | + if "geometry" not in self._df.columns: |
| 598 | + self._df = self._df.with_columns(geometry=st.from_shapely("polygon")) |
| 599 | + if not hasattr(self, "_map_df") and plot_map: |
| 600 | + self._map_df = pl.DataFrame( |
| 601 | + [ |
| 602 | + pl.Series(name="polygon", values=[l.polygon for l in self.map.lanes.values()]), |
| 603 | + pl.Series(name="idx", values=[i for i, _ in enumerate(self.map.lanes.keys())]), |
| 604 | + pl.Series(name="type", values=[o.type.name for o in self.map.lanes.values()]), |
| 605 | + ] |
| 606 | + ) |
| 607 | + self._map_df = self._map_df.with_columns(geometry=st.from_shapely("polygon")) |
| 608 | + |
| 609 | + if end_frame != -1: |
| 610 | + df = self._df.filter(pl.col("frame") < end_frame, pl.col("frame") >= start_frame) |
| 611 | + else: |
| 612 | + df = self._df.filter(pl.col("frame") >= start_frame) |
| 613 | + |
| 614 | + [frame_min], [frame_max] = df.select( |
| 615 | + pl.col("frame").min().alias("min"), |
| 616 | + pl.col("frame").max().alias("max"), |
| 617 | + )[0] |
| 618 | + slider = alt.binding_range(min=frame_min, max=frame_max, step=1, name="frame") |
| 619 | + op_var = alt.param(value=0, bind=slider) |
| 620 | + |
| 621 | + [xmin], [xmax], [ymin], [ymax] = df.select( |
| 622 | + pl.col("x").min().alias("xmin"), |
| 623 | + pl.col("x").max().alias("xmax"), |
| 624 | + pl.col("y").min().alias("ymin"), |
| 625 | + pl.col("y").max().alias("ymax"), |
| 626 | + )[0] |
| 627 | + pov = { |
| 628 | + "type": "Feature", |
| 629 | + "geometry": { |
| 630 | + "type": "Polygon", |
| 631 | + "coordinates": [[[xmax, ymax], [xmax, ymin], [xmin, ymin], [xmin, ymax], [xmax, ymax]]], |
| 632 | + }, |
| 633 | + "properties": {}, |
| 634 | + } |
| 635 | + map = ( |
| 636 | + self._map_df["geometry", "idx", "type"] |
| 637 | + .st.plot(color="green", fillOpacity=0.4) |
| 638 | + .encode(tooltip=["properties.idx:N", "properties.type:O"]) |
| 639 | + ) |
| 640 | + mvs = ( |
| 641 | + df["geometry", "idx", "frame", "type"] |
| 642 | + .st.plot() |
| 643 | + .encode( |
| 644 | + tooltip=["properties.idx:N", "properties.frame:N", "properties.type:O"], |
| 645 | + color=alt.value("blue") |
| 646 | + if idx is None |
| 647 | + else alt.when(alt.FieldEqualPredicate(equal=idx, field="properties.idx")) |
| 648 | + .then(alt.value("red")) |
| 649 | + .otherwise(alt.value("blue")), |
| 650 | + ) |
| 651 | + .transform_filter(alt.FieldEqualPredicate(field="properties.frame", equal=op_var)) |
| 652 | + ) |
| 653 | + |
| 654 | + map_view = ( |
| 655 | + (map + mvs) |
| 656 | + .project("identity", reflectY=True, fit=pov) |
| 657 | + .properties(height=int(ymax - ymin) * 3, width=int(xmax - xmin) * 3, title="Map") |
| 658 | + ) |
| 659 | + view = map_view |
| 660 | + if metric_column is not None and idx is not None: |
| 661 | + metric = ( |
| 662 | + df["idx", metric_column, "frame"] |
| 663 | + .filter(idx=idx) |
| 664 | + .plot.line(x="frame", y=metric_column, color=alt.value("red")) |
| 665 | + .properties(title=f"{metric_column} of object {idx}") |
| 666 | + ) |
| 667 | + vertline = ( |
| 668 | + alt.Chart() |
| 669 | + .mark_rule() |
| 670 | + .encode(x=alt.datum(op_var, type="quantitative", scale=alt.Scale(domain=[frame_min, frame_max]))) |
| 671 | + ) |
| 672 | + view = view | (metric + vertline) |
| 673 | + return view.add_params(op_var) |
0 commit comments