|
| 1 | +from typing import TYPE_CHECKING, Any, List, Type |
| 2 | + |
| 3 | +import numpy as np |
1 | 4 | import taichi as ti |
2 | | -import genesis as gs |
| 5 | +import torch |
3 | 6 |
|
4 | | -from typing import List, Optional |
| 7 | +import genesis as gs |
5 | 8 | from genesis.repr_base import RBC |
6 | 9 |
|
| 10 | +if TYPE_CHECKING: |
| 11 | + from genesis.options.sensors import SensorOptions |
| 12 | + from genesis.utils.ring_buffer import TensorRingBuffer |
| 13 | + |
| 14 | + from .sensor_manager import SensorManager |
| 15 | + |
7 | 16 |
|
8 | 17 | @ti.data_oriented |
9 | 18 | class Sensor(RBC): |
10 | 19 | """ |
11 | 20 | Base class for all types of sensors. |
12 | | - A sensor must have a read() method that returns the sensor data. |
13 | 21 | """ |
14 | 22 |
|
| 23 | + def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_manager: "SensorManager"): |
| 24 | + self._options: "SensorOptions" = sensor_options |
| 25 | + self._idx: int = sensor_idx |
| 26 | + self._manager: "SensorManager" = sensor_manager |
| 27 | + |
| 28 | + # initialized by SensorManager during build |
| 29 | + self._read_delay_steps: int = 0 |
| 30 | + self._shape_indices: list[tuple[int, int]] = [] |
| 31 | + self._shared_metadata: dict[str, Any] | None = None |
| 32 | + self._cache: "TensorRingBuffer" | None = None |
| 33 | + |
| 34 | + # =============================== implementable methods =============================== |
| 35 | + |
| 36 | + def build(self): |
| 37 | + """ |
| 38 | + This method is called by the SensorManager during the scene build phase to initialize the sensor. |
| 39 | + This is where any shared metadata should be initialized. |
| 40 | + """ |
| 41 | + raise NotImplementedError("Sensors must implement `build()`.") |
| 42 | + |
| 43 | + def _get_return_format(self) -> dict[str, tuple[int, ...]] | tuple[int, ...]: |
| 44 | + """ |
| 45 | + Data format of the read() return value. |
| 46 | +
|
| 47 | + Returns |
| 48 | + ------- |
| 49 | + return_format : dict | tuple |
| 50 | + - If tuple, the final shape of the read() return value. |
| 51 | + e.g. (2, 3) means read() will return a tensor of shape (2, 3). |
| 52 | + - If dict a dictionary with string keys and tensor values will be returned. |
| 53 | + e.g. {"pos": (3,), "quat": (4,)} returns a dict of tensors [0:3] and [3:7] from the cache. |
| 54 | + """ |
| 55 | + raise NotImplementedError("Sensors must implement `return_format()`.") |
| 56 | + |
| 57 | + def _get_cache_length(self) -> int: |
| 58 | + """ |
| 59 | + The length of the cache for this sensor instance, e.g. number of points for a Lidar point cloud. |
| 60 | + """ |
| 61 | + raise NotImplementedError("Sensors must implement `cache_length()`.") |
| 62 | + |
| 63 | + @classmethod |
| 64 | + def _update_shared_ground_truth_cache( |
| 65 | + cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor |
| 66 | + ): |
| 67 | + """ |
| 68 | + Update the shared sensor ground truth cache for all sensors of this class using metadata in SensorManager. |
| 69 | + """ |
| 70 | + raise NotImplementedError("Sensors must implement `update_shared_ground_truth_cache()`.") |
| 71 | + |
| 72 | + @classmethod |
| 73 | + def _update_shared_cache( |
| 74 | + cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor, shared_cache: "TensorRingBuffer" |
| 75 | + ): |
| 76 | + """ |
| 77 | + Update the shared sensor cache for all sensors of this class using metadata in SensorManager. |
| 78 | + """ |
| 79 | + raise NotImplementedError("Sensors must implement `update_shared_cache()`.") |
| 80 | + |
| 81 | + @classmethod |
| 82 | + def _get_cache_dtype(cls) -> torch.dtype: |
| 83 | + """ |
| 84 | + The dtype of the cache for this sensor. |
| 85 | + """ |
| 86 | + raise NotImplementedError("Sensors must implement `get_cache_dtype()`.") |
| 87 | + |
| 88 | + # =============================== shared methods =============================== |
| 89 | + |
15 | 90 | @gs.assert_built |
16 | | - def read(self, envs_idx: Optional[List[int]] = None): |
| 91 | + def read(self, envs_idx: List[int] | None = None): |
17 | 92 | """ |
18 | | - Read the sensor data. |
19 | | - Sensor implementations should ideally cache the data to avoid unnecessary computations. |
| 93 | + Read the sensor data (with noise applied if applicable). |
20 | 94 | """ |
21 | | - raise NotImplementedError("The Sensor subclass must implement `read()`.") |
| 95 | + return self._get_formatted_data(self._cache.get(self._read_delay_steps), envs_idx) |
| 96 | + |
| 97 | + @gs.assert_built |
| 98 | + def read_ground_truth(self, envs_idx: List[int] | None = None): |
| 99 | + """ |
| 100 | + Read the ground truth sensor data (without noise). |
| 101 | + """ |
| 102 | + return self._get_formatted_data(self._manager.get_cloned_from_ground_truth_cache(self), envs_idx) |
| 103 | + |
| 104 | + def _get_formatted_data( |
| 105 | + self, tensor: torch.Tensor, envs_idx: list[int] | None |
| 106 | + ) -> torch.Tensor | dict[str, torch.Tensor]: |
| 107 | + # Note: This method does not clone the data tensor, it should have been cloned by the caller. |
| 108 | + |
| 109 | + if envs_idx is None: |
| 110 | + envs_idx = self._manager._sim._scene._envs_idx |
| 111 | + |
| 112 | + return_format = self._get_return_format() |
| 113 | + return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,) |
| 114 | + return_values = [] |
| 115 | + |
| 116 | + for i, shape in enumerate(return_shapes): |
| 117 | + start_idx, end_idx = self._shape_indices[i] |
| 118 | + value = tensor[envs_idx, start_idx:end_idx].reshape(len(envs_idx), *shape).squeeze() |
| 119 | + if self._manager._sim.n_envs == 0: |
| 120 | + value = value.squeeze(0) |
| 121 | + return_values.append(value) |
| 122 | + |
| 123 | + if isinstance(return_format, dict): |
| 124 | + return dict(zip(return_format.keys(), return_values)) |
| 125 | + else: |
| 126 | + return return_values[0] |
| 127 | + |
| 128 | + @property |
| 129 | + def is_built(self) -> bool: |
| 130 | + return self._manager._sim._scene._is_built |
0 commit comments