@@ -288,7 +288,7 @@ def _add_polygons_to_df(self):
288288
289289 @staticmethod
290290 def get_moving_object_ground_truth (
291- nanos : int , df : pl .DataFrame , host_vehicle = None , validate = False
291+ nanos : int , df : pl .DataFrame , host_vehicle_idx : int | None = None , validate : bool = False
292292 ) -> betterosi .GroundTruth :
293293 if validate :
294294 recording_moving_object_schema .validate (df , lazy = True )
@@ -314,13 +314,21 @@ def get_object(row):
314314 version = betterosi .InterfaceVersion (version_major = 3 , version_minor = 7 , version_patch = 9 ),
315315 timestamp = betterosi .Timestamp (seconds = int (nanos // 1_000_000_000 ), nanos = int (nanos % 1_000_000_000 )),
316316 host_vehicle_id = betterosi .Identifier (value = 0 )
317- if host_vehicle is None
318- else betterosi .Identifier (value = host_vehicle ),
317+ if host_vehicle_idx is None
318+ else betterosi .Identifier (value = host_vehicle_idx ),
319319 moving_object = mvs ,
320320 )
321321 return gt
322322
323- def __init__ (self , df , map = None , projections = None , host_vehicle = None , validate = False , compute_polygons = False ):
323+ def __init__ (
324+ self ,
325+ df ,
326+ map = None ,
327+ projections = None ,
328+ host_vehicle_idx : int | None = None ,
329+ validate = False ,
330+ compute_polygons = False ,
331+ ):
324332 if not isinstance (df , pl .DataFrame ):
325333 df = pl .DataFrame (df )
326334 if validate :
@@ -353,7 +361,11 @@ def __init__(self, df, map=None, projections=None, host_vehicle=None, validate=F
353361 self ._moving_objects = (
354362 None # = {int(idx): self._MovingObjectClass(self, idx) for idx in self._df["idx"].unique()}
355363 )
356- self .host_vehicle = host_vehicle
364+ self .host_vehicle_idx = host_vehicle_idx
365+
366+ @property
367+ def host_vehicle (self ):
368+ return self .moving_objects .get (self .host_vehicle_idx , None )
357369
358370 @property
359371 def moving_objects (self ):
@@ -364,7 +376,9 @@ def moving_objects(self):
364376 def to_osi_gts (self ) -> list [betterosi .GroundTruth ]:
365377 first_iteration = True
366378 for [nanos ], group_df in self ._df .sort (["total_nanos" ]).group_by ("total_nanos" , maintain_order = True ):
367- gt = self .get_moving_object_ground_truth (nanos , group_df , host_vehicle = self .host_vehicle , validate = False )
379+ gt = self .get_moving_object_ground_truth (
380+ nanos , group_df , host_vehicle_idx = self .host_vehicle_idx , validate = False
381+ )
368382 if first_iteration :
369383 first_iteration = False
370384 if self .map is not None and isinstance (self .map , MapOsi | MapOsiCenterline ):
@@ -485,7 +499,7 @@ def to_hdf(self, filename, key="moving_object"):
485499 @classmethod
486500 def from_hdf (cls , filename , key = "moving_object" ):
487501 df = pl .DataFrame (pd .read_hdf (filename , key = key ))
488- return cls (df , map = None , host_vehicle = None )
502+ return cls (df , map = None , host_vehicle_idx = None )
489503
490504 def interpolate (self , new_nanos : list [int ] | None = None , hz : float | None = None ):
491505 df = self ._df
@@ -521,7 +535,7 @@ def interpolate(self, new_nanos: list[int] | None = None, hz: float | None = Non
521535 )
522536 new_dfs .append (new_track_df )
523537 new_df = pl .concat (new_dfs )
524- return self .__init__ (new_df , self .map , self .host_vehicle )
538+ return self .__init__ (new_df , self .map , self .host_vehicle_idx )
525539
526540 def plot (self , ax = None , legend = False ) -> plt .Axes :
527541 if ax is None :
0 commit comments