Skip to content

Commit 1def9a7

Browse files
switch from pandas to polars
1 parent e223fc9 commit 1def9a7

File tree

9 files changed

+188
-127
lines changed

9 files changed

+188
-127
lines changed

omega_prime/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from typing import Annotated
33

4-
import pandas as pd
4+
import polars as pl
55
import typer
66

77
import omega_prime
@@ -40,7 +40,7 @@ def from_csv(
4040
validate: bool = True,
4141
skip_odr_parse: bool = False,
4242
):
43-
df = pd.read_csv(input)
43+
df = pl.read_csv(input)
4444
r = omega_prime.Recording(df, validate=validate)
4545
if odr is not None:
4646
r.map = omega_prime.asam_odr.MapOdr.from_file(odr, skip_parse=skip_odr_parse)
@@ -51,7 +51,7 @@ def from_csv(
5151
def validate(
5252
input: Annotated[Path, typer.Argument(help="Path to omega file to validate", exists=True, dir_okay=False)],
5353
):
54-
omega_prime.Recording.from_file(input)
54+
omega_prime.Recording.from_file(input, validate=True)
5555
print(f"File {input} is valid.")
5656

5757

omega_prime/converters/lxd.py

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import multiprocessing as mp
1010
from ..asam_odr import MapOdr
1111
from ..recording import Recording
12-
12+
import polars as pl
1313

1414
__all__ = ["convert_lxd"]
1515
logger.configure(handlers=[{"sink": sys.stdout, "level": "WARNING"}])
@@ -29,10 +29,6 @@
2929
pedestrians = {"pedestrian": betterosi.MovingObjectType.TYPE_PEDESTRIAN}
3030

3131

32-
def wrap_angle(angle):
33-
return (angle + np.pi) % (2 * np.pi) - np.pi
34-
35-
3632
class DatasetConverter:
3733
def __init__(self, dataset_dir: Path) -> None:
3834
self._dataset = Dataset(dataset_dir)
@@ -45,62 +41,70 @@ def get_recording_opendrive_path(self, recording_id: int) -> Path:
4541

4642
def rec2df(self, recording_id):
    """Build the per-frame moving-object table of one recording as a polars DataFrame.

    Track meta data is turned into OSI classification codes and joined onto
    the per-frame track rows, which are then extended with the columns the
    recording format expects.

    Args:
        recording_id: Id of the recording inside the dataset.

    Returns:
        A polars DataFrame with one row per track and frame, containing the
        renamed kinematic columns plus ``role``, ``type``, ``subtype``,
        ``acc_z``, ``z``, ``vel_z``, ``roll``, ``pitch``, ``yaw``,
        ``total_nanos``, ``width``, ``length`` and ``height``.
    """
    rec = self._dataset.get_recording(recording_id)
    dt = 1 / rec.get_meta_data("frameRate")

    def _from_class(mapper, name):
        # All three classification columns are derived from the dataset's "class" label.
        return pl.col("class").map_elements(mapper, return_dtype=int).alias(name)

    meta = rec._tracks_meta_data.with_columns(
        # Everything that is not a known vehicle class is treated as a pedestrian.
        _from_class(
            lambda x: betterosi.MovingObjectType.TYPE_VEHICLE
            if x in vehicles
            else betterosi.MovingObjectType.TYPE_PEDESTRIAN,
            "type",
        ),
        # -1 marks "no role" / "no subtype" for non-vehicles.
        _from_class(
            lambda x: betterosi.MovingObjectVehicleClassificationRole.ROLE_CIVIL if x in vehicles else -1,
            "role",
        ),
        _from_class(lambda x: vehicles[x] if x in vehicles else -1, "subtype"),
    ).rename({"trackId": "idx"})

    tracks = rec._get_tracks_data().rename(
        {
            "xCenter": "x",
            "yCenter": "y",
            "xVelocity": "vel_x",
            "yVelocity": "vel_y",
            "xAcceleration": "acc_x",
            "yAcceleration": "acc_y",
            "trackId": "idx",
        }
    )
    tracks = tracks.join(meta.select(["idx", "role", "type", "subtype"]), on="idx", how="left")

    is_vehicle = pl.col("type") == betterosi.MovingObjectType.TYPE_VEHICLE
    is_bicycle = pl.col("subtype") == betterosi.MovingObjectVehicleClassificationType.TYPE_BICYCLE
    is_pedestrian = pl.col("type") == betterosi.MovingObjectType.TYPE_PEDESTRIAN

    # The pandas implementation this replaces converted ``heading`` with
    # np.deg2rad before wrapping it; without that conversion the wrap below
    # would be applied to degree values and produce meaningless yaw angles.
    heading_rad = pl.col("heading") * (np.pi / 180.0)

    tracks = tracks.with_columns(
        [pl.lit(0.0).alias(k) for k in ["acc_z", "z", "vel_z", "roll", "pitch"]]
        + [
            # Wrap via (a + pi) mod 2*pi - pi.
            (((heading_rad + np.pi) % (2 * np.pi)) - np.pi).alias("yaw"),
            (pl.col("frame") * dt * NANOS_PER_SEC).cast(pl.Int64).alias("total_nanos"),
            # Bicycles and pedestrians get fixed dimensions; all other
            # objects keep their recorded width/length and a height of 2.0.
            pl.when(is_vehicle & is_bicycle)
            .then(0.8)
            .when(is_pedestrian)
            .then(0.5)
            .otherwise(pl.col("width"))
            .alias("width"),
            pl.when(is_vehicle & is_bicycle)
            .then(2.0)
            .when(is_pedestrian)
            .then(0.5)
            .otherwise(pl.col("length"))
            .alias("length"),
            pl.when(is_vehicle & is_bicycle).then(1.9).when(is_pedestrian).then(1.8).otherwise(2.0).alias("height"),
        ]
    )

    return tracks
106110

@@ -109,7 +113,7 @@ def convert_recording(args):
109113
converter, recording_id, out_filename = args
110114
tracks = converter.rec2df(recording_id)
111115
xodr_path = converter.get_recording_opendrive_path(recording_id)
112-
rec = Recording(df=tracks, map=MapOdr.from_file(xodr_path))
116+
rec = Recording(df=tracks, map=MapOdr.from_file(xodr_path), validate=False)
113117
rec.to_mcap(out_filename)
114118

115119

omega_prime/map.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ def create(cls, lane: betterosi.Lane):
9595
subtype=betterosi.LaneClassificationSubtype(lane.classification.subtype),
9696
successor_ids=[],
9797
predecessor_ids=[],
98-
right_boundary_id=None,
99-
left_boundary_id=None,
10098
)
10199

102100
def plot(self, ax: plt.Axes):

omega_prime/recording.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
import betterosi
66
import numpy as np
7-
import pandas as pd
87
import pyproj
98
import shapely
109
from matplotlib import pyplot as plt
1110
from matplotlib.patches import Polygon as PltPolygon
11+
import pandas as pd
12+
import polars as pl
1213

1314
import pandera as pa
1415
import pandera.extensions as extensions
@@ -165,7 +166,7 @@ def __init__(self, recording, idx):
165166
self.idx = int(idx)
166167
self._recording = recording
167168

168-
self._df = self._recording._df.loc[self._recording._df["idx"] == self.idx]
169+
self._df = self._recording._df.filter(pl.col("idx") == self.idx)
169170

170171
for k in [
171172
"x",
@@ -182,23 +183,23 @@ def __init__(self, recording, idx):
182183
"pitch",
183184
"polygon",
184185
]:
185-
setattr(self, k, self._df.loc[:, k].values)
186-
self.vel = np.linalg.norm([self.vel_x, self.vel_y], axis=0)
187-
self.timestamps = self._df.loc[:, "total_nanos"].values / 1e9
186+
setattr(self, k, self._df[:, k])
187+
self.vel = np.linalg.norm(self._df["vel_x", "vel_y"], axis=0)
188+
self.timestamps = self._df["total_nanos"] / 1e9
188189
for k in ["length", "width", "height"]:
189-
setattr(self, f"{k}s", self._df.loc[:, k].values)
190-
setattr(self, k, np.mean(self._df.loc[:, k].values))
190+
setattr(self, f"{k}s", self._df[k])
191+
setattr(self, k, self._df[k].mean())
191192

192-
self.type = betterosi.MovingObjectType(self._df.loc[:, "type"].iloc[0])
193-
subtype_int = self._df.loc[:, "subtype"].iloc[0]
193+
self.type = betterosi.MovingObjectType(self._df["type"][0])
194+
subtype_int = self._df["subtype"][0]
194195
self.subtype = betterosi.MovingObjectVehicleClassificationType(subtype_int) if subtype_int != -1 else None
195-
role_int = self._df.loc[:, "role"].iloc[0]
196+
role_int = self._df["role"][0]
196197
self.role = betterosi.MovingObjectVehicleClassificationRole(role_int) if role_int != -1 else None
197-
self.birth = int(self._df.loc[:, "frame"].iloc[0])
198-
self.end = int(self._df.loc[:, "frame"].iloc[-1])
198+
self.birth = int(self._df["frame"][0])
199+
self.end = int(self._df["frame"][-1])
199200

200-
def set(self, k, val):
201-
self._recording._df.loc[self._recording._df["idx"] == self.idx, k] = val
201+
# def set(self, k, val):
202+
# self._recording._df.loc[self._recording._df["idx"] == self.idx, k] = val
202203

203204
@property
204205
def nanos(self):
@@ -209,7 +210,7 @@ def plot(self, ax: plt.Axes):
209210
pass
210211

211212
def plot_mv_frame(self, ax: plt.Axes, frame: int):
212-
polys = self._df[self._df["frame"] == frame]["polygon"].values
213+
polys = self._df.filter(pl.col("frame") == frame)["polygon"]
213214
for p in polys:
214215
ax.add_patch(PltPolygon(p.exterior.coords, fc="red", alpha=0.2))
215216

@@ -241,10 +242,10 @@ def _get_polygons(df):
241242

242243
@staticmethod
243244
def get_moving_object_ground_truth(
244-
nanos: int, df: pd.DataFrame, host_vehicle=None, validate=True
245+
nanos: int, df: pl.DataFrame, host_vehicle=None, validate=False
245246
) -> betterosi.GroundTruth:
246247
if validate:
247-
recording_moving_object_schema.validate(df, lazy=True)
248+
recording_moving_object_schema.validate(df.to_pandas(), lazy=True)
248249

249250
def get_object(row):
250251
return betterosi.MovingObject(
@@ -262,7 +263,7 @@ def get_object(row):
262263
),
263264
)
264265

265-
mvs = list(df.apply(get_object, axis=1).values)
266+
mvs = [get_object(r) for r in df.iter_rows(named=True)]
266267
gt = betterosi.GroundTruth(
267268
version=betterosi.InterfaceVersion(version_major=3, version_minor=7, version_patch=9),
268269
timestamp=betterosi.Timestamp(seconds=int(nanos // 1_000_000_000), nanos=int(nanos % 1_000_000_000)),
@@ -273,14 +274,14 @@ def get_object(row):
273274
)
274275
return gt
275276

276-
def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=True):
277+
def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=False):
277278
if validate:
278-
recording_moving_object_schema.validate(df, lazy=True)
279+
recording_moving_object_schema.validate(df.to_pandas(), lazy=True)
279280
super().__init__()
280-
self.nanos2frame = {n: i for i, n in enumerate(df.total_nanos.unique())}
281-
df["frame"] = df.total_nanos.map(self.nanos2frame)
282-
if "polygon" not in df.columns:
283-
df["polygon"] = self._get_polygons(df)
281+
self.nanos2frame = {n: i for i, n in enumerate(df["total_nanos"].unique())}
282+
mapping = pl.DataFrame({"total_nanos": list(self.nanos2frame.keys()), "frame": list(self.nanos2frame.values())})
283+
df = df.join(mapping, on="total_nanos", how="left")
284+
df = df.with_columns([pl.Series(name="polygon", values=self._get_polygons(df))])
284285
self.projections = projections
285286
self._df = df
286287
self.map = map
@@ -290,7 +291,7 @@ def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=T
290291
def to_osi_gts(self) -> list[betterosi.GroundTruth]:
291292
gts = [
292293
self.get_moving_object_ground_truth(nanos, group_df, host_vehicle=self.host_vehicle, validate=False)
293-
for nanos, group_df in self._df.groupby("total_nanos")
294+
for [nanos], group_df in self._df.group_by("total_nanos")
294295
]
295296

296297
if self.map is not None and isinstance(self.map, MapOsi | MapOsiCenterline):
@@ -351,11 +352,11 @@ def from_osi_gts(cls, gts: list[betterosi.GroundTruth], validate: bool = True):
351352
)
352353
for mv in gt.moving_object
353354
]
354-
df_mv = pd.DataFrame(mvs).sort_values(by=["total_nanos", "idx"]).reset_index(drop="index")
355+
df_mv = pl.DataFrame(mvs).sort(["total_nanos", "idx"])
355356
return cls(df_mv, projections=projs, validate=validate)
356357

357358
@classmethod
358-
def from_file(cls, filepath, xodr_path: str | None = None, validate: bool = True, skip_odr_parse: bool = False):
359+
def from_file(cls, filepath, xodr_path: str | None = None, validate: bool = False, skip_odr_parse: bool = False):
359360
gts = betterosi.read(
360361
filepath,
361362
return_ground_truth=True,
@@ -396,43 +397,47 @@ def to_mcap(self, filepath):
396397

397398
def to_hdf(self, filename, key="moving_object"):
    """Write the moving-object table to an HDF5 file through pandas.

    The ``polygon`` and ``frame`` columns are left out of the file; both are
    added again by ``__init__`` when a recording is constructed.

    Args:
        filename: Target HDF5 file.
        key: Key under which the table is stored in the file.
    """
    # pandas' HDF5 writer needs the optional PyTables package: pip install tables
    # Column names are passed positionally: ``columns=`` is the pandas
    # keyword, polars' ``DataFrame.drop`` takes the names as positional input.
    persisted = self._df.drop(["polygon", "frame"])
    persisted.to_pandas().to_hdf(filename, key=key)
400401

401402
@classmethod
def from_hdf(cls, filename, key="moving_object"):
    """Create a recording from a moving-object table stored in an HDF5 file.

    Args:
        filename: HDF5 file to read.
        key: Key of the stored table inside the file.

    Returns:
        A new instance built from the stored table, without map and without
        host vehicle.
    """
    # HDF5 is read with pandas; the recording itself works on polars frames.
    stored = pd.read_hdf(filename, key=key)
    table = pl.DataFrame(stored)
    return cls(table, map=None, host_vehicle=None)
405406

406407
def interpolate(self, new_nanos: list[int] | None = None, hz: float | None = None):
    """Resample all tracks onto a new time grid and re-initialise this recording in place.

    Continuous quantities are interpolated linearly, classification codes
    take the value of the nearest sample (via ``nearest_interp``), and
    angles are unwrapped, interpolated linearly and wrapped to [-pi, pi).

    Args:
        new_nanos: Target timestamps in nanoseconds. Only used when ``hz`` is None.
        hz: Target sampling rate; the grid then starts at the first timestamp
            of the recording with a step of ``1e9 / hz`` nanoseconds.
            If both arguments are None, one evenly spaced timestamp per
            existing frame is generated between the first and last timestamp.

    Returns:
        None. The recording is rebuilt through ``__init__`` with the
        resampled table and its current map, projections and host vehicle.
    """
    df = self._df
    t_first, t_last = df["total_nanos"].min(), df["total_nanos"].max()
    first_frame, last_frame = df["frame"].min(), df["frame"].max()

    # Only the default grid has one entry per frame and may therefore be
    # sliced with frame numbers further down.
    frame_aligned = new_nanos is None and hz is None
    if frame_aligned:
        # Frames first_frame..last_frame are (last - first + 1) samples; the
        # previous count of (last - first) left the grid one entry short.
        new_nanos = np.linspace(t_first, t_last, last_frame - first_frame + 1, dtype=int)
    elif hz is not None:
        step = 1_000_000_000 / hz
        new_nanos = np.arange(start=t_first, stop=t_last + 1, step=step, dtype=int)
    else:
        new_nanos = np.array(new_nanos)

    linear_columns = ["x", "y", "z", "vel_x", "vel_y", "vel_z", "acc_x", "acc_y", "acc_z", "length", "width", "height"]
    new_dfs = []
    for [idx], track_df in df.group_by("idx", maintain_order=True):
        # np.interp needs increasing sample positions.
        track_df = track_df.sort("total_nanos")
        track_nanos = track_df["total_nanos"].to_numpy()

        if frame_aligned:
            start = track_df["frame"].min() - first_frame
            stop = track_df["frame"].max() - first_frame + 1
            track_new_nanos = new_nanos[start:stop]
        else:
            # Frame numbers do not index a user supplied or hz based grid,
            # so pick the target timestamps inside the track's lifetime.
            alive = (new_nanos >= track_nanos[0]) & (new_nanos <= track_nanos[-1])
            track_new_nanos = new_nanos[alive]
        if track_new_nanos.size == 0:
            # No target timestamp falls into this track's lifetime.
            continue

        track_data = {}
        for c in linear_columns:
            track_data[c] = np.interp(track_new_nanos, track_nanos, track_df[c].to_numpy())
        for c in ["type", "subtype", "role"]:
            track_data[c] = nearest_interp(track_new_nanos, track_nanos, track_df[c].to_numpy())
        for c in ["roll", "pitch", "yaw"]:
            # Angles are 2*pi periodic. Reducing them modulo pi (as before)
            # maps opposite orientations onto the same value.
            unwrapped = np.unwrap(track_df[c].to_numpy())
            resampled = np.interp(track_new_nanos, track_nanos, unwrapped)
            track_data[c] = (resampled + np.pi) % (2 * np.pi) - np.pi

        new_track_df = pl.DataFrame(track_data).with_columns(
            pl.Series(name="idx", values=np.ones_like(track_new_nanos) * idx),
            pl.Series(name="total_nanos", values=track_new_nanos),
        )
        new_dfs.append(new_track_df)

    # Same ordering as the table built in from_osi_gts.
    new_df = pl.concat(new_dfs).sort(["total_nanos", "idx"])
    # Keyword arguments: the third positional parameter of __init__ is
    # ``projections``, so passing the host vehicle positionally misplaced it
    # and dropped the projections.
    return self.__init__(new_df, map=self.map, projections=self.projections, host_vehicle=self.host_vehicle)
437442

438443
def plot(self, ax=None, legend=False) -> plt.Axes:
@@ -452,6 +457,6 @@ def plot_frame(self, frame: int, ax=None):
452457
return ax
453458

454459
def plot_mv_frame(self, ax: plt.Axes, frame: int):
455-
polys = self._df[self._df["frame"] == frame]["polygon"].values
460+
polys = self._df.filter(pl.col("frame") == frame)["polygon"]
456461
for p in polys:
457462
ax.add_patch(PltPolygon(p.exterior.coords, fc="red"))

0 commit comments

Comments
 (0)