Skip to content

Commit e72688e

Browse files
committed
fix cache indexing and ground_truth_cache_cloned
1 parent d89f2f6 commit e72688e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

genesis/sensors/sensor_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, sim):
2323
self._cache: dict[Type[torch.dtype], TensorRingBuffer] = {}
2424
self._cache_slices_by_type: dict[Type["Sensor"], slice] = {}
2525

26-
self._last_ground_truth_cache_cloned_step: int = -1
26+
self._last_ground_truth_cache_cloned_step: dict[Type[torch.dtype], int] = {}
2727
self._cloned_ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}
2828

2929
def create_sensor(self, sensor_options: "SensorOptions"):
@@ -78,7 +78,7 @@ def build(self):
7878
dtype = sensor_cls._get_cache_dtype()
7979
for sensor in sensors:
8080
sensor._shared_metadata = self._sensors_metadata[sensor_cls]
81-
sensor._cache = self._cache[dtype][sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
81+
sensor._cache = self._cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
8282
sensor.build()
8383

8484
def step(self):
@@ -96,8 +96,8 @@ def step(self):
9696

9797
def get_cloned_from_ground_truth_cache(self, sensor: "Sensor") -> torch.Tensor:
9898
dtype = sensor._get_cache_dtype()
99-
if self._last_ground_truth_cache_cloned_step != self._sim.cur_step_global:
100-
self._last_ground_truth_cache_cloned_step = self._sim.cur_step_global
99+
if self._last_ground_truth_cache_cloned_step[dtype] != self._sim.cur_step_global:
100+
self._last_ground_truth_cache_cloned_step[dtype] = self._sim.cur_step_global
101101
self._cloned_ground_truth_cache[dtype] = self._ground_truth_cache[dtype].clone()
102102
return self._cloned_ground_truth_cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
103103

0 commit comments

Comments
 (0)