Skip to content

Commit 64fe3e3

Browse files
authored
[FEATURE] Add dedicated sensor manager. (Genesis-Embodied-AI#1518)
1 parent 38fcf9b commit 64fe3e3

File tree

8 files changed

+346
-12
lines changed

8 files changed

+346
-12
lines changed

genesis/engine/scene.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
PBDOptions,
2727
ProfilingOptions,
2828
RigidOptions,
29+
SensorOptions,
2930
SFOptions,
3031
SimOptions,
3132
SPHOptions,
@@ -514,6 +515,10 @@ def add_light(
514515
else:
515516
gs.raise_exception("Adding lights is only supported by 'RayTracer' and 'BatchRenderer'.")
516517

518+
@gs.assert_unbuilt
519+
def add_sensor(self, sensor_options: SensorOptions):
520+
return self._sim._sensor_manager.create_sensor(sensor_options)
521+
517522
@gs.assert_unbuilt
518523
def add_camera(
519524
self,

genesis/engine/simulator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from .states.cache import QueriedStates
3838
from .states.solvers import SimState
39+
from genesis.sensors.sensor_manager import SensorManager
3940

4041
if TYPE_CHECKING:
4142
from genesis.engine.scene import Scene
@@ -151,6 +152,9 @@ def __init__(
151152
# entities
152153
self._entities: list[Entity] = gs.List()
153154

155+
# sensors
156+
self._sensor_manager = SensorManager(self)
157+
154158
def _add_entity(self, morph: Morph, material, surface, visualize_contact=False):
155159
if isinstance(material, gs.materials.Tool):
156160
entity = self.tool_solver.add_entity(self.n_entities, material, morph, surface)
@@ -206,6 +210,8 @@ def build(self):
206210
if self.n_envs > 0 and self.sf_solver.is_active():
207211
gs.raise_exception("Batching is not supported for SF solver as of now.")
208212

213+
self._sensor_manager.build()
214+
209215
# hybrid
210216
for entity in self._entities:
211217
if isinstance(entity, HybridEntity):
@@ -275,6 +281,8 @@ def step(self, in_backward=False):
275281
if self.rigid_solver.is_active():
276282
self.rigid_solver.clear_external_force()
277283

284+
self._sensor_manager.step()
285+
278286
def _step_grad(self):
279287
for _ in range(self._substeps - 1, -1, -1):
280288

genesis/options/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from .solvers import *
33
from .vis import *
44
from .profiling import ProfilingOptions
5-
5+
from .sensors import SensorOptions
66

77
__all__ = ["ProfilingOptions"]

genesis/options/sensors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from genesis.options import Options
2+
3+
4+
class SensorOptions(Options):
5+
"""
6+
Base class for all sensor options.
7+
Each sensor should have their own options class that inherits from this class.
8+
The options class should be registered with the SensorManager using the @register_sensor decorator.
9+
10+
Parameters
11+
----------
12+
read_delay : float
13+
The delay in seconds before the sensor data is read.
14+
"""
15+
16+
read_delay: float = 0.0

genesis/sensors/base_sensor.py

Lines changed: 116 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,130 @@
1+
from typing import TYPE_CHECKING, Any, List, Type
2+
3+
import numpy as np
14
import taichi as ti
2-
import genesis as gs
5+
import torch
36

4-
from typing import List, Optional
7+
import genesis as gs
58
from genesis.repr_base import RBC
69

10+
if TYPE_CHECKING:
11+
from genesis.options.sensors import SensorOptions
12+
from genesis.utils.ring_buffer import TensorRingBuffer
13+
14+
from .sensor_manager import SensorManager
15+
716

817
@ti.data_oriented
918
class Sensor(RBC):
1019
"""
1120
Base class for all types of sensors.
12-
A sensor must have a read() method that returns the sensor data.
1321
"""
1422

23+
def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_manager: "SensorManager"):
24+
self._options: "SensorOptions" = sensor_options
25+
self._idx: int = sensor_idx
26+
self._manager: "SensorManager" = sensor_manager
27+
28+
# initialized by SensorManager during build
29+
self._read_delay_steps: int = 0
30+
self._shape_indices: list[tuple[int, int]] = []
31+
self._shared_metadata: dict[str, Any] | None = None
32+
self._cache: "TensorRingBuffer" | None = None
33+
34+
# =============================== implementable methods ===============================
35+
36+
def build(self):
37+
"""
38+
This method is called by the SensorManager during the scene build phase to initialize the sensor.
39+
This is where any shared metadata should be initialized.
40+
"""
41+
raise NotImplementedError("Sensors must implement `build()`.")
42+
43+
def _get_return_format(self) -> dict[str, tuple[int, ...]] | tuple[int, ...]:
44+
"""
45+
Data format of the read() return value.
46+
47+
Returns
48+
-------
49+
return_format : dict | tuple
50+
- If tuple, the final shape of the read() return value.
51+
e.g. (2, 3) means read() will return a tensor of shape (2, 3).
52+
- If dict a dictionary with string keys and tensor values will be returned.
53+
e.g. {"pos": (3,), "quat": (4,)} returns a dict of tensors [0:3] and [3:7] from the cache.
54+
"""
55+
raise NotImplementedError("Sensors must implement `return_format()`.")
56+
57+
def _get_cache_length(self) -> int:
58+
"""
59+
The length of the cache for this sensor instance, e.g. number of points for a Lidar point cloud.
60+
"""
61+
raise NotImplementedError("Sensors must implement `cache_length()`.")
62+
63+
@classmethod
64+
def _update_shared_ground_truth_cache(
65+
cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor
66+
):
67+
"""
68+
Update the shared sensor ground truth cache for all sensors of this class using metadata in SensorManager.
69+
"""
70+
raise NotImplementedError("Sensors must implement `update_shared_ground_truth_cache()`.")
71+
72+
@classmethod
73+
def _update_shared_cache(
74+
cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor, shared_cache: "TensorRingBuffer"
75+
):
76+
"""
77+
Update the shared sensor cache for all sensors of this class using metadata in SensorManager.
78+
"""
79+
raise NotImplementedError("Sensors must implement `update_shared_cache()`.")
80+
81+
@classmethod
82+
def _get_cache_dtype(cls) -> torch.dtype:
83+
"""
84+
The dtype of the cache for this sensor.
85+
"""
86+
raise NotImplementedError("Sensors must implement `get_cache_dtype()`.")
87+
88+
# =============================== shared methods ===============================
89+
1590
@gs.assert_built
16-
def read(self, envs_idx: Optional[List[int]] = None):
91+
def read(self, envs_idx: List[int] | None = None):
1792
"""
18-
Read the sensor data.
19-
Sensor implementations should ideally cache the data to avoid unnecessary computations.
93+
Read the sensor data (with noise applied if applicable).
2094
"""
21-
raise NotImplementedError("The Sensor subclass must implement `read()`.")
95+
return self._get_formatted_data(self._cache.get(self._read_delay_steps), envs_idx)
96+
97+
@gs.assert_built
98+
def read_ground_truth(self, envs_idx: List[int] | None = None):
99+
"""
100+
Read the ground truth sensor data (without noise).
101+
"""
102+
return self._get_formatted_data(self._manager.get_cloned_from_ground_truth_cache(self), envs_idx)
103+
104+
def _get_formatted_data(
105+
self, tensor: torch.Tensor, envs_idx: list[int] | None
106+
) -> torch.Tensor | dict[str, torch.Tensor]:
107+
# Note: This method does not clone the data tensor, it should have been cloned by the caller.
108+
109+
if envs_idx is None:
110+
envs_idx = self._manager._sim._scene._envs_idx
111+
112+
return_format = self._get_return_format()
113+
return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,)
114+
return_values = []
115+
116+
for i, shape in enumerate(return_shapes):
117+
start_idx, end_idx = self._shape_indices[i]
118+
value = tensor[envs_idx, start_idx:end_idx].reshape(len(envs_idx), *shape).squeeze()
119+
if self._manager._sim.n_envs == 0:
120+
value = value.squeeze(0)
121+
return_values.append(value)
122+
123+
if isinstance(return_format, dict):
124+
return dict(zip(return_format.keys(), return_values))
125+
else:
126+
return return_values[0]
127+
128+
@property
129+
def is_built(self) -> bool:
130+
return self._manager._sim._scene._is_built

genesis/sensors/sensor_manager.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from typing import TYPE_CHECKING, Any, Type
2+
3+
import numpy as np
4+
import torch
5+
6+
import genesis as gs
7+
from genesis.utils.ring_buffer import TensorRingBuffer
8+
9+
if TYPE_CHECKING:
10+
from genesis.options.sensors import SensorOptions
11+
12+
from .base_sensor import Sensor
13+
14+
15+
class SensorManager:
16+
SENSOR_TYPES_MAP: dict[Type["SensorOptions"], Type["Sensor"]] = {}
17+
18+
def __init__(self, sim):
19+
self._sim = sim
20+
self._sensors_by_type: dict[Type["Sensor"], list["Sensor"]] = {}
21+
self._sensors_metadata: dict[Type["Sensor"], dict[str, Any]] = {}
22+
self._ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}
23+
self._cache: dict[Type[torch.dtype], TensorRingBuffer] = {}
24+
self._cache_slices_by_type: dict[Type["Sensor"], slice] = {}
25+
26+
self._last_ground_truth_cache_cloned_step: dict[Type[torch.dtype], int] = {}
27+
self._cloned_ground_truth_cache: dict[Type[torch.dtype], torch.Tensor] = {}
28+
29+
def create_sensor(self, sensor_options: "SensorOptions"):
30+
sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)]
31+
self._sensors_by_type.setdefault(sensor_cls, [])
32+
sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), self)
33+
self._sensors_by_type[sensor_cls].append(sensor)
34+
return sensor
35+
36+
def build(self):
37+
max_cache_buf_len = 0
38+
cache_size_per_dtype = {}
39+
for sensor_cls, sensors in self._sensors_by_type.items():
40+
self._sensors_metadata[sensor_cls] = {}
41+
dtype = sensor_cls._get_cache_dtype()
42+
cache_size_per_dtype.setdefault(dtype, 0)
43+
cls_cache_start_idx = cache_size_per_dtype[dtype]
44+
45+
for sensor in sensors:
46+
return_format = sensor._get_return_format()
47+
return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,)
48+
49+
tensor_size = 0
50+
for shape in return_shapes:
51+
data_size = np.prod(shape)
52+
sensor._shape_indices.append((tensor_size, tensor_size + data_size))
53+
tensor_size += data_size
54+
55+
delay_steps_float = sensor._options.read_delay / self._sim.dt
56+
sensor._read_delay_steps = round(delay_steps_float)
57+
if not np.isclose(delay_steps_float, sensor._read_delay_steps, atol=1e-6):
58+
gs.logger.warn(
59+
f"Read delay should be a multiple of the simulation time step. Got {sensor._options.read_delay}"
60+
f" and {self._sim.dt}. Actual read delay will be {1/sensor._read_delay_steps}."
61+
)
62+
63+
sensor._cache_size = sensor._get_cache_length() * tensor_size
64+
sensor._cache_idx = cache_size_per_dtype[dtype]
65+
cache_size_per_dtype[dtype] += sensor._cache_size
66+
67+
max_cache_buf_len = max(max_cache_buf_len, sensor._read_delay_steps + 1)
68+
69+
cls_cache_end_idx = cache_size_per_dtype[dtype]
70+
self._cache_slices_by_type[sensor_cls] = slice(cls_cache_start_idx, cls_cache_end_idx)
71+
72+
for dtype in cache_size_per_dtype.keys():
73+
cache_shape = (self._sim._B, cache_size_per_dtype[dtype])
74+
self._ground_truth_cache[dtype] = torch.zeros(cache_shape, dtype=dtype)
75+
self._cache[dtype] = TensorRingBuffer(max_cache_buf_len, cache_shape, dtype=dtype)
76+
77+
for sensor_cls, sensors in self._sensors_by_type.items():
78+
dtype = sensor_cls._get_cache_dtype()
79+
for sensor in sensors:
80+
sensor._shared_metadata = self._sensors_metadata[sensor_cls]
81+
sensor._cache = self._cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
82+
sensor.build()
83+
84+
def step(self):
85+
for sensor_cls in self._sensors_by_type.keys():
86+
dtype = sensor_cls._get_cache_dtype()
87+
cache_slice = self._cache_slices_by_type[sensor_cls]
88+
sensor_cls._update_shared_ground_truth_cache(
89+
self._sensors_metadata[sensor_cls], self._ground_truth_cache[dtype][cache_slice]
90+
)
91+
sensor_cls._update_shared_cache(
92+
self._sensors_metadata[sensor_cls],
93+
self._ground_truth_cache[dtype][cache_slice],
94+
self._cache[dtype][cache_slice],
95+
)
96+
97+
def get_cloned_from_ground_truth_cache(self, sensor: "Sensor") -> torch.Tensor:
98+
dtype = sensor._get_cache_dtype()
99+
if self._last_ground_truth_cache_cloned_step[dtype] != self._sim.cur_step_global:
100+
self._last_ground_truth_cache_cloned_step[dtype] = self._sim.cur_step_global
101+
self._cloned_ground_truth_cache[dtype] = self._ground_truth_cache[dtype].clone()
102+
return self._cloned_ground_truth_cache[dtype][:, sensor._cache_idx : sensor._cache_idx + sensor._cache_size]
103+
104+
@property
105+
def sensors(self):
106+
return tuple([sensor for sensor_list in self._sensors_by_type.values() for sensor in sensor_list])
107+
108+
109+
def register_sensor(sensor_cls: Type["Sensor"]):
110+
def _impl(options_cls: Type["SensorOptions"]):
111+
SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls
112+
return options_cls
113+
114+
return _impl

0 commit comments

Comments
 (0)