Skip to content

Commit c893cb5

Browse files
rename host_vehicle to host_vehicle_idx
1 parent 1912414 commit c893cb5

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

omega_prime/recording.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _add_polygons_to_df(self):
288288

289289
@staticmethod
290290
def get_moving_object_ground_truth(
291-
nanos: int, df: pl.DataFrame, host_vehicle=None, validate=False
291+
nanos: int, df: pl.DataFrame, host_vehicle_idx: int | None = None, validate: bool = False
292292
) -> betterosi.GroundTruth:
293293
if validate:
294294
recording_moving_object_schema.validate(df, lazy=True)
@@ -314,13 +314,21 @@ def get_object(row):
314314
version=betterosi.InterfaceVersion(version_major=3, version_minor=7, version_patch=9),
315315
timestamp=betterosi.Timestamp(seconds=int(nanos // 1_000_000_000), nanos=int(nanos % 1_000_000_000)),
316316
host_vehicle_id=betterosi.Identifier(value=0)
317-
if host_vehicle is None
318-
else betterosi.Identifier(value=host_vehicle),
317+
if host_vehicle_idx is None
318+
else betterosi.Identifier(value=host_vehicle_idx),
319319
moving_object=mvs,
320320
)
321321
return gt
322322

323-
def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=False, compute_polygons=False):
323+
def __init__(
324+
self,
325+
df,
326+
map=None,
327+
projections=None,
328+
host_vehicle_idx: int | None = None,
329+
validate=False,
330+
compute_polygons=False,
331+
):
324332
if not isinstance(df, pl.DataFrame):
325333
df = pl.DataFrame(df)
326334
if validate:
@@ -353,7 +361,11 @@ def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=F
353361
self._moving_objects = (
354362
None # = {int(idx): self._MovingObjectClass(self, idx) for idx in self._df["idx"].unique()}
355363
)
356-
self.host_vehicle = host_vehicle
364+
self.host_vehicle_idx = host_vehicle_idx
365+
366+
@property
367+
def host_vehicle(self):
368+
return self.moving_objects.get(self.host_vehicle_idx, None)
357369

358370
@property
359371
def moving_objects(self):
@@ -364,7 +376,9 @@ def moving_objects(self):
364376
def to_osi_gts(self) -> list[betterosi.GroundTruth]:
365377
first_iteration = True
366378
for [nanos], group_df in self._df.sort(["total_nanos"]).group_by("total_nanos", maintain_order=True):
367-
gt = self.get_moving_object_ground_truth(nanos, group_df, host_vehicle=self.host_vehicle, validate=False)
379+
gt = self.get_moving_object_ground_truth(
380+
nanos, group_df, host_vehicle_idx=self.host_vehicle_idx, validate=False
381+
)
368382
if first_iteration:
369383
first_iteration = False
370384
if self.map is not None and isinstance(self.map, MapOsi | MapOsiCenterline):
@@ -485,7 +499,7 @@ def to_hdf(self, filename, key="moving_object"):
485499
@classmethod
486500
def from_hdf(cls, filename, key="moving_object"):
487501
df = pl.DataFrame(pd.read_hdf(filename, key=key))
488-
return cls(df, map=None, host_vehicle=None)
502+
return cls(df, map=None, host_vehicle_idx=None)
489503

490504
def interpolate(self, new_nanos: list[int] | None = None, hz: float | None = None):
491505
df = self._df
@@ -521,7 +535,7 @@ def interpolate(self, new_nanos: list[int] | None = None, hz: float | None = Non
521535
)
522536
new_dfs.append(new_track_df)
523537
new_df = pl.concat(new_dfs)
524-
return self.__init__(new_df, self.map, self.host_vehicle)
538+
return self.__init__(new_df, self.map, self.host_vehicle_idx)
525539

526540
def plot(self, ax=None, legend=False) -> plt.Axes:
527541
if ax is None:

0 commit comments

Comments
 (0)