Skip to content

Commit d89f2f6

Browse files
committed
no cloned_cache
1 parent 925fb0d commit d89f2f6

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

genesis/sensors/base_sensor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_mana
2929
self._read_delay_steps: int = 0
3030
self._shape_indices: list[tuple[int, int]] = []
3131
self._shared_metadata: dict[str, Any] | None = None
32+
self._cache: "TensorRingBuffer" | None = None
3233

3334
# =============================== implementable methods ===============================
3435

@@ -91,10 +92,7 @@ def read(self, envs_idx: List[int] | None = None):
9192
"""
9293
Read the sensor data (with noise applied if applicable).
9394
"""
94-
return self._get_formatted_data(
95-
self._manager.get_cloned_from_cache(self).get(self._read_delay_steps),
96-
envs_idx,
97-
)
95+
return self._get_formatted_data(self._cache.get(self._read_delay_steps), envs_idx)
9896

9997
@gs.assert_built
10098
def read_ground_truth(self, envs_idx: List[int] | None = None):

genesis/sensors/sensor_manager.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ def __init__(self, sim):
2323
self._cache: dict[Type[torch.dtype], TensorRingBuffer] = {}
2424
self._cache_slices_by_type: dict[Type["Sensor"], slice] = {}
2525

26-
self._last_cache_cloned_step: int = -1
2726
self._last_ground_truth_cache_cloned_step: int = -1
28-
self._cloned_cache: dict[Type[torch.dtype], TensorRingBuffer] = {}
2927
self._cloned_ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}
3028

3129
def create_sensor(self, sensor_options: "SensorOptions"):
@@ -80,11 +78,7 @@ def build(self):
8078
dtype = sensor_cls._get_cache_dtype()
8179
for sensor in sensors:
8280
sensor._shared_metadata = self._sensors_metadata[sensor_cls]
83-
84-
cache_slice = slice(sensor._cache_idx, sensor._cache_idx + sensor._cache_size)
85-
sensor._cache = self._cache[dtype][cache_slice]
86-
sensor._ground_truth_cache = self._ground_truth_cache[dtype][cache_slice]
87-
81+
sensor._cache = self._cache[dtype][sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
8882
sensor.build()
8983

9084
def step(self):
@@ -100,13 +94,6 @@ def step(self):
10094
self._cache[dtype][cache_slice],
10195
)
10296

103-
def get_cloned_from_cache(self, sensor: "Sensor") -> "TensorRingBuffer":
104-
dtype = sensor._get_cache_dtype()
105-
if self._last_cache_cloned_step != self._sim.cur_step_global:
106-
self._last_cache_cloned_step = self._sim.cur_step_global
107-
self._cloned_cache[dtype] = self._cache[dtype].clone()
108-
return self._cloned_cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
109-
11097
def get_cloned_from_ground_truth_cache(self, sensor: "Sensor") -> torch.Tensor:
11198
dtype = sensor._get_cache_dtype()
11299
if self._last_ground_truth_cache_cloned_step != self._sim.cur_step_global:

genesis/utils/ring_buffer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@ def append(self, tensor: torch.Tensor):
2929
self.buffer[self._idx_ptr.value].copy_(tensor)
3030
self._idx_ptr.value = (self._idx_ptr.value + 1) % self.N
3131

32-
def get(self, idx: int):
32+
def get(self, idx: int, clone: bool = True):
3333
"""
3434
Parameters
3535
----------
3636
idx : int
3737
Index of the element to get, where 0 is the latest element, 1 is the second latest, etc.
38+
clone : bool
39+
Whether to clone the tensor.
3840
"""
39-
return self.buffer[(self._idx_ptr.value - idx) % self.N]
41+
tensor = self.buffer[(self._idx_ptr.value - idx) % self.N]
42+
return tensor.clone() if clone else tensor
4043

4144
def clone(self):
4245
return TensorRingBuffer(

0 commit comments

Comments
 (0)