import graphlib
from collections.abc import Callable
from dataclasses import dataclass, field

import polars as pl
import polars_st as st

from .recording import Recording


@dataclass
class Metric:
    # compute_func receives a LazyFrame (plus optional positional/keyword
    # arguments) and returns the extended LazyFrame together with a dict of
    # named property frames.
    compute_func: Callable[..., tuple[pl.LazyFrame, dict[str, pl.LazyFrame]]]
    computes_columns: list[str] = field(default_factory=list)
    computes_properties: list[str] = field(default_factory=list)
    requires_columns: list[str] = field(default_factory=list)
    requires_properties: list[str] = field(default_factory=list)

    def compute_lazy(self, df: pl.LazyFrame, *args, **kwargs) -> tuple[pl.LazyFrame, dict[str, pl.LazyFrame]]:
        return self.compute_func(df, *args, **kwargs)


@dataclass
class MetricManager:
    metrics: list[Metric]
    _dependencies: dict[int | str, list[int | str]] = field(init=False)
    _ordered_metrics: list[Metric] = field(init=False)

    def __post_init__(self):
        # Build a node -> predecessors graph for graphlib: every computed
        # column/property node depends on the metric (by index) that computes
        # it, and every metric index depends on the column/property nodes it
        # requires.
        self._dependencies = {
            val: [i]
            for i, m in enumerate(self.metrics)
            for val in [f"column_{n}" for n in m.computes_columns] + [f"property_{n}" for n in m.computes_properties]
        } | {
            i: [f"column_{n}" for n in m.requires_columns] + [f"property_{n}" for n in m.requires_properties]
            for i, m in enumerate(self.metrics)
        }

        # Required nodes that no metric computes can never be satisfied.
        unresolved_dependencies = {
            k: missing
            for k, vv in self._dependencies.items()
            if (missing := [v for v in vv if v not in self._dependencies])
        }
        if len(unresolved_dependencies) > 0:
            error_dict = {f"self.metrics[{k}]": v for k, v in unresolved_dependencies.items()}
            raise RuntimeError(
                f"Some metrics require columns or properties that are never computed: {error_dict}"
            )

        ts = graphlib.TopologicalSorter(self._dependencies)
        self._ordered_metrics = [self.metrics[o] for o in ts.static_order() if isinstance(o, int)]
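        # For the two metrics defined at the bottom of this module, the graph
        # is roughly (node -> predecessors):
        #   {"column_distance_traveled": [0], "column_vel": [0],
        #    "property_timegaps": [1], ...,
        #    0: [], 1: ["column_distance_traveled", "column_vel"]}
        # so the topological order runs metric 0 before metric 1.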

    def __repr__(self):
        columns = [c for m in self._ordered_metrics for c in m.computes_columns]
        properties = [p for m in self._ordered_metrics for p in m.computes_properties]
        return f"computes columns: {columns} - computes properties: {properties}"

    def compute(self, r: Recording, *args, **kwargs) -> tuple[pl.DataFrame, dict[str, pl.DataFrame]]:
        if "polygon" not in r._df.columns:
            r._df = r._add_polygons(r._df)
        if "geometry" not in r._df.columns:
            r._df = r._df.with_columns(geometry=st.from_shapely("polygon"))

        df = r._df.lazy()
        properties: dict[str, pl.LazyFrame] = {}
        for m in self._ordered_metrics:
            # Hand previously computed property frames to metrics that declare
            # them in requires_properties.
            df, new_p = m.compute_lazy(df, *args, **{k: properties[k] for k in m.requires_properties}, **kwargs)
            properties |= new_p
        # Materialize the main frame and all property frames in one pass so
        # shared subplans are executed only once.
        res = pl.collect_all([df] + list(properties.values()))
        df, computed_props = res[0], res[1:]
        return df, dict(zip(properties.keys(), computed_props))


def add_driven_distance_and_vel(df: pl.LazyFrame, *args, **kwargs) -> tuple[pl.LazyFrame, dict[str, pl.LazyFrame]]:
    return df.with_columns(
        # Cumulative euclidean step length per object: the whole chain runs
        # inside the per-`idx` window so distances do not accumulate across
        # different objects.
        (pl.col("x").diff() ** 2 + pl.col("y").diff() ** 2)
        .sqrt()
        .fill_null(0.0)
        .cum_sum()
        .over("idx")
        .alias("distance_traveled"),
        (pl.col("vel_x") ** 2 + pl.col("vel_y") ** 2).sqrt().alias("vel"),
    ), {}


driven_distance_and_vel = Metric(computes_columns=["distance_traveled", "vel"], compute_func=add_driven_distance_and_vel)


def get_timegaps(df: pl.LazyFrame, ego_id, *args, time_buffer=2e9, **kwargs):
    ego_df = df.filter(idx=ego_id)

    # Pair every sample of every object with every ego sample ...
    crossed = df.join(ego_df, how="cross", suffix="_ego")

    # ... then keep only pairs at most time_buffer nanoseconds apart, and drop
    # the ego paired with itself.
    crossed = crossed.filter(
        (pl.col("total_nanos_ego") - time_buffer) <= pl.col("total_nanos"),
        (pl.col("total_nanos_ego") + time_buffer) >= pl.col("total_nanos"),
        pl.col("idx_ego") != pl.col("idx"),
    )

    # A timegap exists wherever the two footprints overlap in space; its sign
    # says whether the other object occupies the shared space before (<0) or
    # after (>0) the ego.
    all_timegaps = (
        crossed.filter(pl.col("geometry").st.intersects(pl.col("geometry_ego")))
        .with_columns(timegap=(pl.col("total_nanos") - pl.col("total_nanos_ego")) / 1e9)
        .select(
            "idx_ego", "idx", "total_nanos_ego", "total_nanos", "timegap", "distance_traveled", "distance_traveled_ego"
        )
    )

    # Per ego timestamp, keep the overlap with the smallest absolute timegap.
    timegaps = (
        all_timegaps.group_by("idx", "idx_ego", "total_nanos_ego")
        .agg(
            pl.col("timegap", "total_nanos", "distance_traveled", "distance_traveled_ego").get(
                pl.col("timegap").abs().arg_min()
            ),
        )
        .sort("idx_ego", "idx", "total_nanos_ego")
        .select(
            "idx_ego", "idx", "total_nanos_ego", "timegap", "total_nanos", "distance_traveled", "distance_traveled_ego"
        )
    )
    min_timegaps = timegaps.group_by("idx_ego", "idx").agg(
        pl.col("timegap").get(pl.col("timegap").abs().arg_min()).alias("min_timegap")
    )

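    # p_timegap ("predicted" timegap): before an object reaches the overlap
    # area, its arrival there is extrapolated as remaining distance along the
    # path divided by its current speed; after it has passed, the actual
    # timestamp difference is used. With arrival times t + tto (other) and
    # t_ego + tto_ego (ego), the predicted gap other-minus-ego is
    # (t + tto) - (t_ego + tto_ego), i.e. the negated sum built below.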
    p_timegaps = (
        crossed.join(timegaps, how="right", suffix="_overlap", on=["idx", "idx_ego"])
        .with_columns(
            pl.when(pl.col("total_nanos") >= pl.col("total_nanos_overlap"))
            .then((pl.col("total_nanos_overlap") - pl.col("total_nanos")) / 1e9)
            .otherwise((pl.col("distance_traveled_overlap") - pl.col("distance_traveled")) / pl.col("vel"))
            .alias("time_to_overlap"),
            pl.when(pl.col("total_nanos_ego") >= pl.col("total_nanos_ego_overlap"))
            .then((pl.col("total_nanos_ego_overlap") - pl.col("total_nanos_ego")) / 1e9)
            .otherwise((pl.col("distance_traveled_ego_overlap") - pl.col("distance_traveled_ego")) / pl.col("vel_ego"))
            .alias("time_to_overlap_ego"),
        )
        .with_columns(
            (
                -(
                    pl.col("time_to_overlap_ego")
                    - pl.col("time_to_overlap")
                    + (pl.col("total_nanos_ego") - pl.col("total_nanos")) / 1e9
                )
            ).alias("p_timegap")
        )
        .group_by("idx_ego", "idx", "total_nanos_ego")
        .agg(
            pl.col("p_timegap", "total_nanos")
            .sort_by(pl.col("p_timegap").abs(), descending=False, nulls_last=True)
            .first()
        )
        .sort("idx_ego", "idx", "total_nanos_ego")
    )

    min_p_timegaps = p_timegaps.group_by("idx_ego", "idx").agg(
        pl.col("p_timegap").sort_by(pl.col("p_timegap").abs(), descending=False).first()
    )

    return df, {
        "timegaps": timegaps,
        "min_timegaps": min_timegaps,
        "p_timegaps": p_timegaps,
        "min_p_timegaps": min_p_timegaps,
    }


timegaps_and_p_timegaps = Metric(
    requires_columns=["distance_traveled", "vel"],
    compute_func=get_timegaps,
    computes_columns=[],
    computes_properties=["timegaps", "min_timegaps", "p_timegaps", "min_p_timegaps"],
)
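

# Minimal usage sketch (hedged: `load_recording` is a hypothetical helper, not
# part of this module; any Recording whose frame has `idx`, `x`, `y`, `vel_x`,
# `vel_y`, `total_nanos` and polygon data would do):
#
#     manager = MetricManager(metrics=[driven_distance_and_vel, timegaps_and_p_timegaps])
#     df, props = manager.compute(load_recording("some_recording"), ego_id=0)
#     props["min_timegaps"]  # minimal signed timegap per (ego, other) pair
#
# The topological sort guarantees driven_distance_and_vel runs before
# timegaps_and_p_timegaps, regardless of their order in `metrics`.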