|
1 | | -from typing import Type |
| 1 | +from typing import Any, Type |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
5 | 5 | from genesis.options.sensors import SensorOptions |
6 | 6 |
|
7 | | -from .base_sensor import Sensor |
| 7 | +from typing import TYPE_CHECKING |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from .base_sensor import Sensor |
8 | 11 |
|
9 | 12 |
|
10 | 13 | class SensorManager: |
11 | | - SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type[Sensor]] = {} |
12 | | - SENSOR_CACHE_METADATA_MAP: dict[Type[Sensor], (torch.dtype, tuple[int, ...])] = {} |
| 14 | + SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type["Sensor"]] = {} |
| 15 | + SENSOR_CACHE_METADATA_MAP: dict[Type["Sensor"], (torch.dtype, tuple[int, ...])] = {} |
13 | 16 |
|
14 | | - def __init__(self): |
15 | | - self.sensors_by_type: dict[Type[Sensor], list[Sensor]] = {} |
16 | | - self.cache: dict[Type[Sensor], torch.Tensor] = {} |
17 | | - self.cache_size_map: dict[Type[Sensor], int] = {} |
| 17 | + def __init__(self, sim): |
| 18 | + self._sim = sim |
| 19 | + self._sensors_by_type: dict[Type["Sensor"], list["Sensor"]] = {} |
| 20 | + self._sensors_metadata: dict[Type["Sensor"], dict[str, Any]] = {} |
| 21 | + self._cache: dict[Type["Sensor"], torch.Tensor] = {} |
| 22 | + self._cache_size_map: dict[Type["Sensor"], int] = {} |
| 23 | + self._cache_last_updated_step_map: dict[Type["Sensor"], int] = {} |
18 | 24 |
|
19 | 25 | def create_sensor(self, sensor_options: SensorOptions): |
20 | 26 | sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] |
21 | | - sensor = sensor_cls(sensor_options, len(self.sensors_by_type[sensor_cls]), self) |
22 | | - if sensor_cls not in self.sensors_by_type: |
23 | | - self.sensors_by_type[sensor_cls] = [] |
24 | | - self.sensors_by_type[sensor_cls].append(sensor) |
| 27 | + if sensor_cls not in self._sensors_by_type: |
| 28 | + self._sensors_by_type[sensor_cls] = [] |
| 29 | + sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), self) |
| 30 | + self._sensors_by_type[sensor_cls].append(sensor) |
25 | 31 | return sensor |
26 | 32 |
|
27 | 33 | def build(self): |
28 | | - for sensor_cls, sensors in self.sensors_by_type.items(): |
| 34 | + for sensor_cls, sensors in self._sensors_by_type.items(): |
29 | 35 | total_cache_length = 0 |
| 36 | + self._cache_last_updated_step_map[sensor_cls] = -1 |
| 37 | + self._sensors_metadata[sensor_cls] = {} |
30 | 38 | for sensor in sensors: |
31 | 39 | sensor.build() |
32 | 40 | sensor._cache_idx = total_cache_length |
33 | 41 | total_cache_length += sensor.cache_length |
34 | 42 |
|
35 | 43 | cache_dtype, cache_shape = SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] |
36 | | - self.cache[sensor_cls] = torch.zeros((total_cache_length, *cache_shape), dtype=cache_dtype) |
37 | | - |
38 | | - for sensor in sensors: |
39 | | - sensor.build() |
| 44 | + self._cache[sensor_cls] = torch.zeros((self._sim._B, total_cache_length, *cache_shape), dtype=cache_dtype) |
40 | 45 |
|
41 | | - def get_sensor_cache(self, sensor_cls: Type[Sensor], sensor_idx: int | None = None) -> torch.Tensor: |
42 | | - cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls] |
43 | | - if sensor_idx is None: |
44 | | - return self.cache[sensor_cls] |
45 | | - return self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] |
| 46 | + def is_cache_updated(self, sensor_cls: Type["Sensor"]) -> bool: |
| 47 | + return self._cache_last_updated_step_map[sensor_cls] == self._sim.cur_step_global |
46 | 48 |
|
47 | | - def set_sensor_cache(self, new_values: torch.Tensor, sensor_cls: Type[Sensor], sensor_idx: int | None = None): |
48 | | - cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls] |
49 | | - if sensor_idx is None: |
50 | | - self.cache[sensor_cls] = new_values |
51 | | - else: |
52 | | - self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] = new_values |
| 49 | + def set_cache_updated(self, sensor_cls: Type["Sensor"]): |
| 50 | + self._cache_last_updated_step_map[sensor_cls] = self._sim.cur_step_global |
53 | 51 |
|
54 | 52 | @property |
55 | 53 | def sensors(self): |
56 | | - return [sensor for sensor_list in self.sensors_by_type.values() for sensor in sensor_list] |
| 54 | + return [sensor for sensor_list in self._sensors_by_type.values() for sensor in sensor_list] |
57 | 55 |
|
58 | 56 |
|
59 | | -def register_sensor(sensor_cls: Type[Sensor], cache_dtype: torch.dtype, cache_shape: tuple[int, ...]): |
60 | | - def _impl(sensor_options: SensorOptions): |
61 | | - SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] = sensor_cls |
62 | | - SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (cache_dtype, cache_shape) |
63 | | - return sensor_options |
| 57 | +def register_sensor(sensor_cls: Type["Sensor"]): |
| 58 | + def _impl(options_cls: Type[SensorOptions]): |
| 59 | + SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls |
| 60 | + SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (sensor_cls.CACHE_DTYPE, sensor_cls.CACHE_SHAPE) |
| 61 | + return options_cls |
64 | 62 |
|
65 | 63 | return _impl |
0 commit comments