|
| 1 | +from typing import Type |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from genesis.options.sensors import SensorOptions |
| 6 | + |
| 7 | +from .base_sensor import Sensor |
| 8 | + |
| 9 | + |
| 10 | +class SensorManager: |
| 11 | + SENSOR_TYPES_MAP: dict[Type[SensorOptions], Type[Sensor]] = {} |
| 12 | + SENSOR_CACHE_METADATA_MAP: dict[Type[Sensor], (torch.dtype, tuple[int, ...])] = {} |
| 13 | + |
| 14 | + def __init__(self): |
| 15 | + self.sensors_by_type: dict[Type[Sensor], list[Sensor]] = {} |
| 16 | + self.cache: dict[Type[Sensor], torch.Tensor] = {} |
| 17 | + self.cache_size_map: dict[Type[Sensor], int] = {} |
| 18 | + |
| 19 | + def create_sensor(self, sensor_options: SensorOptions): |
| 20 | + sensor_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] |
| 21 | + sensor = sensor_cls(sensor_options, len(self.sensors_by_type[sensor_cls]), self) |
| 22 | + if sensor_cls not in self.sensors_by_type: |
| 23 | + self.sensors_by_type[sensor_cls] = [] |
| 24 | + self.sensors_by_type[sensor_cls].append(sensor) |
| 25 | + return sensor |
| 26 | + |
| 27 | + def build(self): |
| 28 | + for sensor_cls, sensors in self.sensors_by_type.items(): |
| 29 | + total_cache_length = 0 |
| 30 | + for sensor in sensors: |
| 31 | + sensor.build() |
| 32 | + sensor._cache_idx = total_cache_length |
| 33 | + total_cache_length += sensor.cache_length |
| 34 | + |
| 35 | + cache_dtype, cache_shape = SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] |
| 36 | + self.cache[sensor_cls] = torch.zeros((total_cache_length, *cache_shape), dtype=cache_dtype) |
| 37 | + |
| 38 | + for sensor in sensors: |
| 39 | + sensor.build() |
| 40 | + |
| 41 | + def get_sensor_cache(self, sensor_cls: Type[Sensor], sensor_idx: int | None = None) -> torch.Tensor: |
| 42 | + cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls] |
| 43 | + if sensor_idx is None: |
| 44 | + return self.cache[sensor_cls] |
| 45 | + return self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] |
| 46 | + |
| 47 | + def set_sensor_cache(self, new_values: torch.Tensor, sensor_cls: Type[Sensor], sensor_idx: int | None = None): |
| 48 | + cache_size = SensorManager.SENSOR_CACHE_SIZE_MAP[sensor_cls] |
| 49 | + if sensor_idx is None: |
| 50 | + self.cache[sensor_cls] = new_values |
| 51 | + else: |
| 52 | + self.cache[sensor_cls][sensor_idx * cache_size : (sensor_idx + 1) * cache_size] = new_values |
| 53 | + |
| 54 | + @property |
| 55 | + def sensors(self): |
| 56 | + return [sensor for sensor_list in self.sensors_by_type.values() for sensor in sensor_list] |
| 57 | + |
| 58 | + |
| 59 | +def register_sensor(sensor_cls: Type[Sensor], cache_dtype: torch.dtype, cache_shape: tuple[int, ...]): |
| 60 | + def _impl(sensor_options: SensorOptions): |
| 61 | + SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] = sensor_cls |
| 62 | + SensorManager.SENSOR_CACHE_METADATA_MAP[sensor_cls] = (cache_dtype, cache_shape) |
| 63 | + return sensor_options |
| 64 | + |
| 65 | + return _impl |
0 commit comments