Skip to content

Commit a1f289f

Browse files
oliverkallenbergkallenberg
andauthored
Add projection utility and improve standard plots (#21)
* Add bool to apply the proj information right after loding * Add functionality to plot function * Use better marker (rectangle) for traffic lights * Formatting * Check if the recording has a projection defined * Change to OP version 0.2.2 * Set apply_proj parameter to false in specific examples * Set apply_proj to False * Add execution count to metadata to pass pipeline * Remove old if statement --------- Co-authored-by: kallenberg <oliver.kallenberg@ika.rwth-aachen.de>
1 parent 814a022 commit a1f289f

File tree

4 files changed

+50
-14
lines changed

4 files changed

+50
-14
lines changed

docs/notebooks/tutorial.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727
},
2828
{
2929
"cell_type": "code",
30-
"execution_count": 10,
30+
"execution_count": null,
3131
"metadata": {},
3232
"outputs": [],
3333
"source": [
3434
"import omega_prime\n",
3535
"\n",
3636
"r = omega_prime.Recording.from_file(\n",
37-
" \"../../example_files/pedestrian.osi\", map_path=\"../../example_files/fabriksgatan.xodr\"\n",
37+
" \"../../example_files/pedestrian.osi\", map_path=\"../../example_files/fabriksgatan.xodr\", apply_proj=False\n",
3838
")\n",
3939
"r.to_mcap(\"example.mcap\")"
4040
]
@@ -194,7 +194,7 @@
194194
}
195195
],
196196
"source": [
197-
"r = omega_prime.Recording.from_file(\"example.mcap\", parse_map=True)\n",
197+
"r = omega_prime.Recording.from_file(\"example.mcap\", parse_map=True, apply_proj=False)\n",
198198
"ax = r.plot()\n",
199199
"ax.set_xlim(-20, 50)\n",
200200
"ax.set_ylim(-75, 20)"

docs/notebooks/tutorial_locator.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
},
2020
{
2121
"cell_type": "code",
22-
"execution_count": 2,
22+
"execution_count": null,
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
2626
"r = omega_prime.Recording.from_file(\n",
27-
" \"../../example_files/pedestrian.osi\", map_path=\"../../example_files/fabriksgatan.xodr\"\n",
27+
" \"../../example_files/pedestrian.osi\", map_path=\"../../example_files/fabriksgatan.xodr\", apply_proj=False\n",
2828
")"
2929
]
3030
},

omega_prime/recording.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def from_file(
584584
validate: bool = False,
585585
parse_map: bool = False,
586586
step_size: float = 0.01,
587+
apply_proj: bool = True,
587588
**kwargs,
588589
) -> "Recording":
589590
"""Load a Recording from a file. Supports `.parquet`, `.osi` and `.mcap` files.
@@ -594,6 +595,7 @@ def from_file(
594595
validate (bool): Whether to validate the data against the schema.
595596
parse_map (bool): Whether to create python objects from the map data or just load it.
596597
step_size (float): Step size for map parsing, if applicable (Used for ASAM OpenDRIVE).
598+
apply_proj (bool): Whether to apply projection transformations to the recording's moving object data.
597599
598600
Returns:
599601
Recording (Recording): The loaded Recording object.
@@ -624,6 +626,12 @@ def from_file(
624626
r.map = map
625627
elif r.map is None:
626628
warn(f"No map could be found: {map_parsing}")
629+
630+
if r.projections and apply_proj:
631+
try:
632+
r.apply_projections()
633+
except Exception:
634+
warn("Failed to apply projections.")
627635
return r
628636

629637
def to_file(self, filepath):
@@ -824,28 +832,56 @@ def interpolate(self, new_nanos: list[int] | None = None, hz: float | None = Non
824832
new_df = pl.concat(new_dfs)
825833
return self.__init__(df=new_df, map=self.map, host_vehicle_idx=self.host_vehicle_idx)
826834

827-
def plot(self, ax=None, legend=False) -> plt.Axes:
835+
def _create_legend(self, ax):
836+
handles, labels = ax.get_legend_handles_labels()
837+
host_label = f"{self.host_vehicle_idx} - HV"
838+
839+
def sort_key(item):
840+
label = item[1]
841+
if label == host_label:
842+
return (-1, -1)
843+
try:
844+
return (0, int(label))
845+
except ValueError:
846+
return (0, float("inf")) # non-numeric labels go last
847+
848+
items = sorted(zip(handles, labels), key=sort_key)
849+
handles, labels = zip(*items)
850+
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5))
851+
return ax
852+
853+
def plot(self, ax=None, legend=False, mvs_plt_type: str = "scatter") -> plt.Axes:
828854
"Generate a static plot of the recording using Matplotlib. Plots the map (if available), moving objects, and traffic light states."
829855
if ax is None:
830856
fig, ax = plt.subplots(1, 1)
831857
ax.set_aspect(1)
832858
if self.map:
833859
self.map.plot(ax)
834-
self.plot_mvs(ax=ax)
860+
self.plot_mvs(ax=ax, mvs_plt_type=mvs_plt_type)
835861
self.plot_tl(ax=ax)
836862
if legend:
837-
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
863+
ax = self._create_legend(ax)
838864
return ax
839865

840-
def plot_mvs(self, ax=None, legend=False):
866+
def plot_mvs(self, ax=None, legend=False, mvs_plt_type: str = "scatter"):
841867
"Generate a static plot of the moving objects in the recording using Matplotlib."
842868
if ax is None:
843869
fig, ax = plt.subplots(1, 1)
844870
ax.set_aspect(1)
845-
for [idx], mv in self._df["idx", "x", "y"].group_by("idx"):
846-
ax.plot(*mv["x", "y"], c="red", alpha=0.5, label=str(idx))
871+
plot_fn = {"scatter": ax.scatter, "plot": ax.plot}.get(mvs_plt_type)
872+
if plot_fn is None:
873+
raise ValueError("`mvs_plt_type` must be one of: 'scatter', 'plot'.")
874+
875+
plot_df = self._df["idx", "x", "y"]
876+
base_kwargs = {"alpha": 0.5}
877+
for [idx], mv in plot_df.group_by("idx"):
878+
if idx == self.host_vehicle_idx:
879+
ax.plot(*mv["x", "y"], c="red", label=f"{idx} - HV")
880+
continue
881+
plot_fn(*mv["x", "y"], label=str(idx), **base_kwargs)
882+
847883
if legend:
848-
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
884+
ax = self._create_legend(ax)
849885
return ax
850886

851887
def plot_tl(self, ax=None):
@@ -866,7 +902,7 @@ def plot_tl(self, ax=None):
866902
ax.plot(
867903
x,
868904
y,
869-
marker="o",
905+
marker="s",
870906
label=f"Traffic Light {tl_dict[tl].id.value}",
871907
c="blue",
872908
alpha=0.7,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies = [
4848
'tqdm_joblib',
4949
'filelock>=3.18.0'
5050
]
51-
version = "0.2.1"
51+
version = "0.2.2"
5252

5353
[project.urls]
5454
Homepage = "https://github.com/ika-rwth-aachen/omega-prime"

0 commit comments

Comments
 (0)