Skip to content

Commit c44347c

Browse files
committed
add SensorManager
1 parent ae58b2c commit c44347c

File tree

5 files changed

+115
-4
lines changed

5 files changed

+115
-4
lines changed

genesis/engine/scene.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
PBDOptions,
2626
ProfilingOptions,
2727
RigidOptions,
28+
SensorOptions,
2829
SFOptions,
2930
SimOptions,
3031
SPHOptions,
@@ -513,6 +514,10 @@ def add_light(
513514
else:
514515
gs.raise_exception("Adding lights is only supported by 'RayTracer' and 'BatchRenderer'.")
515516

517+
@gs.assert_unbuilt
518+
def add_sensor(self, sensor_options: SensorOptions):
519+
return self._sim._sensor_manager.create_sensor(sensor_options)
520+
516521
@gs.assert_unbuilt
517522
def add_camera(
518523
self,

genesis/engine/simulator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737
from .states.cache import QueriedStates
3838
from .states.solvers import SimState
39+
from genesis.sensors.sensor_manager import SensorManager
3940

4041
if TYPE_CHECKING:
4142
from genesis.engine.scene import Scene
@@ -151,6 +152,9 @@ def __init__(
151152
# entities
152153
self._entities: list[Entity] = gs.List()
153154

155+
# sensors
156+
self._sensor_manager = SensorManager()
157+
154158
def _add_entity(self, morph: Morph, material, surface, visualize_contact=False):
155159
if isinstance(material, gs.materials.Tool):
156160
entity = self.tool_solver.add_entity(self.n_entities, material, morph, surface)
@@ -206,6 +210,8 @@ def build(self):
206210
if self.n_envs > 0 and self.sf_solver.is_active():
207211
gs.raise_exception("Batching is not supported for SF solver as of now.")
208212

213+
self._sensor_manager.build()
214+
209215
# hybrid
210216
for entity in self._entities:
211217
if isinstance(entity, HybridEntity):

genesis/options/sensors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from genesis.options import Options
2+
3+
4+
class SensorOptions(Options):
5+
"""
6+
Base class for all sensor options.
7+
Each sensor should have their own options class that inherits from this class.
8+
The options class should be registered with the SensorManager using the @register_sensor decorator.
9+
"""
10+
11+
pass

genesis/sensors/base_sensor.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from typing import List, Optional
2+
13
import taichi as ti
2-
import genesis as gs
34

4-
from typing import List, Optional
5+
import genesis as gs
6+
from genesis.options.sensors import SensorOptions
57
from genesis.repr_base import RBC
68

9+
from .sensor_manager import SensorManager
10+
711

812
@ti.data_oriented
913
class Sensor(RBC):
@@ -12,10 +16,30 @@ class Sensor(RBC):
1216
A sensor must have a read() method that returns the sensor data.
1317
"""
1418

19+
def __init__(self, sensor_options: SensorOptions, sensor_idx: int, sensor_manager: SensorManager):
20+
self._options: SensorOptions = sensor_options
21+
self._idx: int = sensor_idx
22+
self._manager: SensorManager = sensor_manager
23+
self._cache_idx: int = -1 # cache_idx is set by the SensorManager during the scene build phase
24+
25+
@gs.assert_unbuilt
26+
def build(self):
27+
"""
28+
This method is called by the SensorManager during the scene build phase to initialize the sensor.
29+
"""
30+
pass
31+
1532
@gs.assert_built
1633
def read(self, envs_idx: Optional[List[int]] = None):
1734
"""
1835
Read the sensor data.
19-
Sensor implementations should ideally cache the data to avoid unnecessary computations.
36+
Sensor implementations should make use of the caching system located in SensorManager when possible.
37+
"""
38+
raise NotImplementedError("Sensors must implement `read()`.")
39+
40+
@property
41+
def cache_length(self) -> int:
42+
"""
43+
The length (first dimension of cache size) of the cache for this sensor.
2044
"""
21-
raise NotImplementedError("The Sensor subclass must implement `read()`.")
45+
raise NotImplementedError("Sensors must implement `cache_length()`.")

genesis/sensors/sensor_manager.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Type
2+
3+
import torch
4+
5+
from genesis.options.sensors import SensorOptions
6+
7+
from .base_sensor import Sensor
8+
9+
10+
class SensorManager:
11+
SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type[Sensor]] = {}
12+
SENSOR_CACHE_METADATA_MAP: dict[Type[Sensor], (torch.dtype, tuple[int, ...])] = {}
13+
14+
def __init__(self):
15+
self.sensors_by_type: dict[Type[Sensor], list[Sensor]] = {}
16+
self.cache: dict[Type[Sensor], torch.Tensor] = {}
17+
self.cache_size_map: dict[Type[Sensor], int] = {}
18+
19+
def create_sensor(self, sensor_options: SensorOptions):
20+
sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)]
21+
sensor = sensor_cls(sensor_options, len(self.sensors_by_type[sensor_cls]), self)
22+
if sensor_cls not in self.sensors_by_type:
23+
self.sensors_by_type[sensor_cls] = []
24+
self.sensors_by_type[sensor_cls].append(sensor)
25+
return sensor
26+
27+
def build(self):
28+
for sensor_cls, sensors in self.sensors_by_type.items():
29+
total_cache_length = 0
30+
for sensor in sensors:
31+
sensor.build()
32+
sensor._cache_idx = total_cache_length
33+
total_cache_length += sensor.cache_length
34+
35+
cache_dtype, cache_shape = SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls]
36+
self.cache[sensor_cls] = torch.zeros((total_cache_length, *cache_shape), dtype=cache_dtype)
37+
38+
for sensor in sensors:
39+
sensor.build()
40+
41+
def get_sensor_cache(self, sensor_cls: Type[Sensor], sensor_idx: int | None = None) -> torch.Tensor:
42+
cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls]
43+
if sensor_idx is None:
44+
return self.cache[sensor_cls]
45+
return self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size]
46+
47+
def set_sensor_cache(self, new_values: torch.Tensor, sensor_cls: Type[Sensor], sensor_idx: int | None = None):
48+
cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls]
49+
if sensor_idx is None:
50+
self.cache[sensor_cls] = new_values
51+
else:
52+
self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] = new_values
53+
54+
@property
55+
def sensors(self):
56+
return [sensor for sensor_list in self.sensors_by_type.values() for sensor in sensor_list]
57+
58+
59+
def register_sensor(sensor_cls: Type[Sensor], cache_dtype: torch.dtype, cache_shape: tuple[int, ...]):
60+
def _impl(sensor_options: SensorOptions):
61+
SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] = sensor_cls
62+
SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (cache_dtype, cache_shape)
63+
return sensor_options
64+
65+
return _impl

0 commit comments

Comments
 (0)