@@ -23,7 +23,7 @@ def __init__(self, sim):
2323 self ._cache : dict [Type [torch .dtype ], TensorRingBuffer ] = {}
2424 self ._cache_slices_by_type : dict [Type ["Sensor" ], slice ] = {}
2525
26- self ._last_ground_truth_cache_cloned_step : int = - 1
26+ self ._last_ground_truth_cache_cloned_step : dict [ Type [ torch . dtype ], int ] = {}
2727 self ._cloned_ground_truth_cache : dict [Type [torch .dtype ], torch .Tensor ] = {}
2828
2929 def create_sensor (self , sensor_options : "SensorOptions" ):
@@ -78,7 +78,7 @@ def build(self):
7878 dtype = sensor_cls ._get_cache_dtype ()
7979 for sensor in sensors :
8080 sensor ._shared_metadata = self ._sensors_metadata [sensor_cls ]
81- sensor ._cache = self ._cache [dtype ][sensor ._cache_idx : sensor ._cache_idx + sensor ._cache_size ]
81+ sensor ._cache = self ._cache [dtype ][:, sensor ._cache_idx : sensor ._cache_idx + sensor ._cache_size ]
8282 sensor .build ()
8383
8484 def step (self ):
@@ -96,8 +96,8 @@ def step(self):
9696
9797 def get_cloned_from_ground_truth_cache (self , sensor : "Sensor" ) -> torch .Tensor :
9898 dtype = sensor ._get_cache_dtype ()
99- if self ._last_ground_truth_cache_cloned_step != self ._sim .cur_step_global :
100- self ._last_ground_truth_cache_cloned_step = self ._sim .cur_step_global
99+ if self ._last_ground_truth_cache_cloned_step [ dtype ] != self ._sim .cur_step_global :
100+ self ._last_ground_truth_cache_cloned_step [ dtype ] = self ._sim .cur_step_global
101101 self ._cloned_ground_truth_cache [dtype ] = self ._ground_truth_cache [dtype ].clone ()
102102 return self ._cloned_ground_truth_cache [dtype ][:, sensor ._cache_idx : sensor ._cache_idx + sensor ._cache_size ]
103103
0 commit comments