Skip to content

Commit 05c0fc5

Browse files
committed
fix sensor manager metadata
1 parent c44347c commit 05c0fc5

File tree

4 files changed

+70
-39
lines changed

4 files changed

+70
-39
lines changed

genesis/engine/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
self._entities: list[Entity] = gs.List()
154154

155155
# sensors
156-
self._sensor_manager = SensorManager()
156+
self._sensor_manager = SensorManager(self)
157157

158158
def _add_entity(self, morph: Morph, material, surface, visualize_contact=False):
159159
if isinstance(material, gs.materials.Tool):

genesis/options/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from .solvers import *
33
from .vis import *
44
from .profiling import ProfilingOptions
5-
5+
from .sensors import SensorOptions
66

77
__all__ = ["ProfilingOptions"]

genesis/sensors/base_sensor.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import List, Optional
1+
from typing import Any, List, Optional
22

33
import taichi as ti
4+
import torch
45

56
import genesis as gs
67
from genesis.options.sensors import SensorOptions
@@ -16,12 +17,19 @@ class Sensor(RBC):
1617
A sensor must have a read() method that returns the sensor data.
1718
"""
1819

20+
# These class variables are used by SensorManager to determine the cache metadata for the sensor.
21+
# Sensor implementations should override these class variable values.
22+
CACHE_DTYPE: torch.dtype = torch.float32
23+
CACHE_SHAPE: tuple[int, ...] = (1,)
24+
1925
def __init__(self, sensor_options: SensorOptions, sensor_idx: int, sensor_manager: SensorManager):
2026
self._options: SensorOptions = sensor_options
2127
self._idx: int = sensor_idx
2228
self._manager: SensorManager = sensor_manager
2329
self._cache_idx: int = -1 # cache_idx is set by the SensorManager during the scene build phase
2430

31+
# =============================== implementable methods ===============================
32+
2533
@gs.assert_unbuilt
2634
def build(self):
2735
"""
@@ -33,7 +41,6 @@ def build(self):
3341
def read(self, envs_idx: Optional[List[int]] = None):
3442
"""
3543
Read the sensor data.
36-
Sensor implementations should make use of the caching system located in SensorManager when possible.
3744
"""
3845
raise NotImplementedError("Sensors must implement `read()`.")
3946

@@ -42,4 +49,30 @@ def cache_length(self) -> int:
4249
"""
4350
The length (first dimension of cache size) of the cache for this sensor.
4451
"""
45-
raise NotImplementedError("Sensors must implement `cache_length()`.")
52+
return 1
53+
54+
# =============================== shared methods ===============================
55+
56+
@property
57+
def is_built(self) -> bool:
58+
return self._manager._sim._scene._is_built
59+
60+
@gs.assert_built
61+
def _get_cache(self) -> torch.Tensor:
62+
return self._manager.get_sensor_cache(self.__class__, self._cache_idx)
63+
64+
@gs.assert_built
65+
def _is_cache_updated(self) -> bool:
66+
return self._manager.is_cache_updated(self.__class__)
67+
68+
@gs.assert_built
69+
def _set_cache_updated(self):
70+
self._manager.set_cache_updated(self.__class__)
71+
72+
@property
73+
def _cache(self) -> torch.Tensor:
74+
return self._manager._cache[self.__class__]
75+
76+
@property
77+
def _shared_metadata(self) -> dict[str, Any]:
78+
return self._manager._sensors_metadata[self.__class__]

genesis/sensors/sensor_manager.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,63 @@
1-
from typing import Type
1+
from typing import Any, Type
22

33
import torch
44

55
from genesis.options.sensors import SensorOptions
66

7-
from .base_sensor import Sensor
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from .base_sensor import Sensor
811

912

1013
class SensorManager:
11-
SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type[Sensor]] = {}
12-
SENSOR_CACHE_METADATA_MAP: dict[Type[Sensor], (torch.dtype, tuple[int, ...])] = {}
14+
SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type["Sensor"]] = {}
15+
SENSOR_CACHE_METADATA_MAP: dict[Type["Sensor"], (torch.dtype, tuple[int, ...])] = {}
1316

14-
def __init__(self):
15-
self.sensors_by_type: dict[Type[Sensor], list[Sensor]] = {}
16-
self.cache: dict[Type[Sensor], torch.Tensor] = {}
17-
self.cache_size_map: dict[Type[Sensor], int] = {}
17+
def __init__(self, sim):
18+
self._sim = sim
19+
self._sensors_by_type: dict[Type["Sensor"], list["Sensor"]] = {}
20+
self._sensors_metadata: dict[Type["Sensor"], dict[str, Any]] = {}
21+
self._cache: dict[Type["Sensor"], torch.Tensor] = {}
22+
self._cache_size_map: dict[Type["Sensor"], int] = {}
23+
self._cache_last_updated_step_map: dict[Type["Sensor"], int] = {}
1824

1925
def create_sensor(self, sensor_options: SensorOptions):
2026
sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)]
21-
sensor = sensor_cls(sensor_options, len(self.sensors_by_type[sensor_cls]), self)
22-
if sensor_cls not in self.sensors_by_type:
23-
self.sensors_by_type[sensor_cls] = []
24-
self.sensors_by_type[sensor_cls].append(sensor)
27+
if sensor_cls not in self._sensors_by_type:
28+
self._sensors_by_type[sensor_cls] = []
29+
sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), self)
30+
self._sensors_by_type[sensor_cls].append(sensor)
2531
return sensor
2632

2733
def build(self):
28-
for sensor_cls, sensors in self.sensors_by_type.items():
34+
for sensor_cls, sensors in self._sensors_by_type.items():
2935
total_cache_length = 0
36+
self._cache_last_updated_step_map[sensor_cls] = -1
37+
self._sensors_metadata[sensor_cls] = {}
3038
for sensor in sensors:
3139
sensor.build()
3240
sensor._cache_idx = total_cache_length
3341
total_cache_length += sensor.cache_length
3442

3543
cache_dtype, cache_shape = SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls]
36-
self.cache[sensor_cls] = torch.zeros((total_cache_length, *cache_shape), dtype=cache_dtype)
37-
38-
for sensor in sensors:
39-
sensor.build()
44+
self._cache[sensor_cls] = torch.zeros((self._sim._B, total_cache_length, *cache_shape), dtype=cache_dtype)
4045

41-
def get_sensor_cache(self, sensor_cls: Type[Sensor], sensor_idx: int | None = None) -> torch.Tensor:
42-
cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls]
43-
if sensor_idx is None:
44-
return self.cache[sensor_cls]
45-
return self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size]
46+
def is_cache_updated(self, sensor_cls: Type["Sensor"]) -> bool:
47+
return self._cache_last_updated_step_map[sensor_cls] == self._sim.cur_step_global
4648

47-
def set_sensor_cache(self, new_values: torch.Tensor, sensor_cls: Type[Sensor], sensor_idx: int | None = None):
48-
cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls]
49-
if sensor_idx is None:
50-
self.cache[sensor_cls] = new_values
51-
else:
52-
self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] = new_values
49+
def set_cache_updated(self, sensor_cls: Type["Sensor"]):
50+
self._cache_last_updated_step_map[sensor_cls] = self._sim.cur_step_global
5351

5452
@property
5553
def sensors(self):
56-
return [sensor for sensor_list in self.sensors_by_type.values() for sensor in sensor_list]
54+
return [sensor for sensor_list in self._sensors_by_type.values() for sensor in sensor_list]
5755

5856

59-
def register_sensor(sensor_cls: Type[Sensor], cache_dtype: torch.dtype, cache_shape: tuple[int, ...]):
60-
def _impl(sensor_options: SensorOptions):
61-
SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] = sensor_cls
62-
SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (cache_dtype, cache_shape)
63-
return sensor_options
57+
def register_sensor(sensor_cls: Type["Sensor"]):
58+
def _impl(options_cls: Type[SensorOptions]):
59+
SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls
60+
SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (sensor_cls.CACHE_DTYPE, sensor_cls.CACHE_SHAPE)
61+
return options_cls
6462

6563
return _impl

0 commit comments

Comments
 (0)