1- from typing import TYPE_CHECKING , Any , List , Type
1+ from dataclasses import dataclass , field
2+ from typing import TYPE_CHECKING , Any , List
23
34import numpy as np
45import taichi as ti
56import torch
67
78import genesis as gs
9+ from genesis .options import Options
810from genesis .repr_base import RBC
911
1012if TYPE_CHECKING :
11- from genesis .options .sensors import SensorOptions
1213 from genesis .utils .ring_buffer import TensorRingBuffer
1314
1415 from .sensor_manager import SensorManager
1516
1617
18+ class SensorOptions (Options ):
19+ """
20+ Base class for all sensor options.
21+ Each sensor should have their own options class that inherits from this class.
22+ The options class should be registered with the SensorManager using the @register_sensor decorator.
23+
24+ Parameters
25+ ----------
26+ read_delay : float
27+ The delay in seconds before the sensor data is read.
28+ """
29+
30+ read_delay : float = 0.0
31+
32+ def validate (self , scene ):
33+ """
34+ Validate the sensor options values before the sensor is added to the scene.
35+ """
36+ read_delay_hz = self .read_delay / scene ._sim .dt
37+ if not np .isclose (read_delay_hz , round (read_delay_hz ), atol = 1e-6 ):
38+ gs .logger .warn (
39+ f"Read delay should be a multiple of the simulation time step. Got { self .read_delay } "
40+ f" and { scene ._sim .dt } . Actual read delay will be { 1 / round (read_delay_hz )} ."
41+ )
42+
43+
44+ @dataclass
45+ class SharedSensorMetadata :
46+ """
47+ Shared metadata between all sensors of the same class.
48+ """
49+
50+ cache_sizes : list [int ] = field (default_factory = list )
51+ read_delay_steps : list [int ] = field (default_factory = list )
52+
53+
1754@ti .data_oriented
1855class Sensor (RBC ):
1956 """
2057 Base class for all types of sensors.
58+
59+ NOTE: The Sensor system is designed to be performant. All sensors of the same type are updated at once and stored
60+ in a cache in SensorManager. Cache size is inferred from the return format and cache length of each sensor.
61+ `read()` and `read_ground_truth()`, the public-facing methods of every Sensor, automatically handles indexing into
62+ the shared cache to return the correct data.
2163 """
2264
2365 def __init__ (self , sensor_options : "SensorOptions" , sensor_idx : int , sensor_manager : "SensorManager" ):
2466 self ._options : "SensorOptions" = sensor_options
2567 self ._idx : int = sensor_idx
2668 self ._manager : "SensorManager" = sensor_manager
69+ self ._shared_metadata : SharedSensorMetadata = sensor_manager ._sensors_metadata [type (self )]
70+
71+ self ._read_delay_steps = round (self ._options .read_delay / self ._manager ._sim .dt )
72+ self ._shared_metadata .read_delay_steps .append (self ._read_delay_steps )
2773
28- # initialized by SensorManager during build
29- self ._read_delay_steps : int = 0
3074 self ._shape_indices : list [tuple [int , int ]] = []
31- self ._shared_metadata : dict [str , Any ] | None = None
32- self ._cache : "TensorRingBuffer" | None = None
75+ return_format = self ._get_return_format ()
76+ return_shapes = return_format .values () if isinstance (return_format , dict ) else (return_format ,)
77+ tensor_size = 0
78+ for shape in return_shapes :
79+ data_size = np .prod (shape )
80+ self ._shape_indices .append ((tensor_size , tensor_size + data_size ))
81+ tensor_size += data_size
82+
83+ self ._cache_size = self ._get_cache_length () * tensor_size
84+ self ._shared_metadata .cache_sizes .append (self ._cache_size )
85+
86+ self ._cache_idx : int = - 1 # initialized by SensorManager during build
3387
3488 # =============================== implementable methods ===============================
3589
@@ -71,12 +125,19 @@ def _update_shared_ground_truth_cache(
71125
72126 @classmethod
73127 def _update_shared_cache (
74- cls , shared_metadata : dict [str , Any ], shared_ground_truth_cache : torch .Tensor , shared_cache : "TensorRingBuffer"
128+ cls ,
129+ shared_metadata : dict [str , Any ],
130+ shared_ground_truth_cache : torch .Tensor ,
131+ shared_cache : torch .Tensor ,
132+ buffered_data : "TensorRingBuffer" ,
75133 ):
76134 """
77135 Update the shared sensor cache for all sensors of this class using metadata in SensorManager.
136+
137+ The information in shared_cache should be the final measured sensor data after all noise and post-processing.
138+ NOTE: The implementation should include applying the delay using the `_apply_delay_to_shared_cache()` method.
78139 """
79- raise NotImplementedError ("Sensors must implement `update_shared_cache ()`." )
140+ raise NotImplementedError ("Sensors must implement `update_shared_cache_with_noise ()`." )
80141
81142 @classmethod
82143 def _get_cache_dtype (cls ) -> torch .dtype :
@@ -92,19 +153,35 @@ def read(self, envs_idx: List[int] | None = None):
92153 """
93154 Read the sensor data (with noise applied if applicable).
94155 """
95- return self ._get_formatted_data (self ._cache . get (self . _read_delay_steps ), envs_idx )
156+ return self ._get_formatted_data (self ._manager . get_cloned_from_cache (self ), envs_idx )
96157
97158 @gs .assert_built
98159 def read_ground_truth (self , envs_idx : List [int ] | None = None ):
99160 """
100161 Read the ground truth sensor data (without noise).
101162 """
102- return self ._get_formatted_data (self ._manager .get_cloned_from_ground_truth_cache (self ), envs_idx )
163+ return self ._get_formatted_data (self ._manager .get_cloned_from_cache (self , is_ground_truth = True ), envs_idx )
164+
165+ @classmethod
166+ def _apply_delay_to_shared_cache (
167+ self , shared_metadata : SharedSensorMetadata , shared_cache : torch .Tensor , buffered_data : "TensorRingBuffer"
168+ ):
169+ """
170+ Applies the read delay to the shared cache tensor by copying the buffered data at the appropriate index.
171+ """
172+ idx = 0
173+ for tensor_size , read_delay_step in zip (shared_metadata .cache_sizes , shared_metadata .read_delay_steps ):
174+ shared_cache [:, idx : idx + tensor_size ] = buffered_data .at (read_delay_step )[:, idx : idx + tensor_size ]
175+ idx += tensor_size
103176
104177 def _get_formatted_data (
105178 self , tensor : torch .Tensor , envs_idx : list [int ] | None
106179 ) -> torch .Tensor | dict [str , torch .Tensor ]:
107- # Note: This method does not clone the data tensor, it should have been cloned by the caller.
180+ """
181+ Formats the flattened cache tensor into a dict of tensors using the format specified in `_get_return_format()`.
182+
183+ NOTE: This method does not clone the data tensor, it should have been cloned by the caller.
184+ """
108185
109186 if envs_idx is None :
110187 envs_idx = self ._manager ._sim ._scene ._envs_idx
0 commit comments