Skip to content

Commit 4a2987e

Browse files
Milotrinceduburcqa
andauthored
[FEATURE] Add IMU sensor. (#1551)
Co-authored-by: Alexis Duburcq <alexis.duburcq@gmail.com>
1 parent 8b5733d commit 4a2987e

File tree

13 files changed

+424
-124
lines changed

13 files changed

+424
-124
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,27 +1730,6 @@ def get_links_ang(self, links_idx_local=None, envs_idx=None, *, unsafe=False):
17301730
links_idx = self._get_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
17311731
return self._solver.get_links_ang(links_idx, envs_idx, unsafe=unsafe)
17321732

1733-
@gs.assert_built
1734-
def get_links_accelerometer_data(self, links_idx_local=None, envs_idx=None, *, imu=False, unsafe=False):
1735-
"""
1736-
Returns the accelerometer data that would be measured by a IMU rigidly attached to the specified entity's links,
1737-
i.e. the true linear acceleration of the links expressed at their respective origin in local frame coordinates.
1738-
1739-
Parameters
1740-
----------
1741-
links_idx_local : None | array_like
1742-
The indices of the links. Defaults to None.
1743-
envs_idx : None | array_like, optional
1744-
The indices of the environments. If None, all environments will be considered. Defaults to None.
1745-
1746-
Returns
1747-
-------
1748-
acc : torch.Tensor, shape (n_links, 3) or (n_envs, n_links, 3)
1749-
The accelerometer data of IMUs rigidly attached of the specified entity's links.
1750-
"""
1751-
links_idx = self._get_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
1752-
return self._solver.get_links_acc(links_idx, envs_idx, mimick_imu=True, unsafe=unsafe)
1753-
17541733
@gs.assert_built
17551734
def get_links_acc(self, links_idx_local=None, envs_idx=None, *, unsafe=False):
17561735
"""
@@ -1770,7 +1749,7 @@ def get_links_acc(self, links_idx_local=None, envs_idx=None, *, unsafe=False):
17701749
The linear classical acceleration of the specified entity's links.
17711750
"""
17721751
links_idx = self._get_idx(links_idx_local, self.n_links, self._link_start, unsafe=True)
1773-
return self._solver.get_links_acc(links_idx, envs_idx, mimick_imu=False, unsafe=unsafe)
1752+
return self._solver.get_links_acc(links_idx, envs_idx, unsafe=unsafe)
17741753

17751754
@gs.assert_built
17761755
def get_links_acc_ang(self, links_idx_local=None, envs_idx=None, *, unsafe=False):

genesis/engine/scene.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33
import sys
44
import time
5+
from typing import TYPE_CHECKING
56

67
import numpy as np
78
import torch
@@ -26,7 +27,6 @@
2627
PBDOptions,
2728
ProfilingOptions,
2829
RigidOptions,
29-
SensorOptions,
3030
SFOptions,
3131
SimOptions,
3232
SPHOptions,
@@ -43,6 +43,9 @@
4343
from genesis.vis import Visualizer
4444
from genesis.utils.warnings import warn_once
4545

46+
if TYPE_CHECKING:
47+
from genesis.sensors.base_sensor import SensorOptions
48+
4649

4750
@gs.assert_initialized
4851
class Scene(RBC):
@@ -516,7 +519,7 @@ def add_light(
516519
gs.raise_exception("Adding lights is only supported by 'RayTracer' and 'BatchRenderer'.")
517520

518521
@gs.assert_unbuilt
519-
def add_sensor(self, sensor_options: SensorOptions):
522+
def add_sensor(self, sensor_options: "SensorOptions"):
520523
return self._sim._sensor_manager.create_sensor(sensor_options)
521524

522525
@gs.assert_unbuilt

genesis/engine/solvers/base_solver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def _kernel_set_gravity(self, gravity: ti.types.ndarray(), envs_idx: ti.types.nd
6363
for j in ti.static(range(3)):
6464
self._gravity[envs_idx[i_b_]][j] = gravity[i_b_, j]
6565

66+
def get_gravity(self, envs_idx=None, *, unsafe=False):
67+
tensor = ti_field_to_torch(self._gravity, envs_idx, transpose=True, unsafe=unsafe)
68+
return tensor.squeeze(0) if self.n_envs == 0 else tensor
69+
6670
def dump_ckpt_to_numpy(self) -> dict[str, np.ndarray]:
6771
arrays: dict[str, np.ndarray] = {}
6872

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,18 +2043,16 @@ def get_links_ang(self, links_idx=None, envs_idx=None, *, unsafe=False):
20432043
tensor = ti_field_to_torch(self.links_state.cd_ang, envs_idx, links_idx, transpose=True, unsafe=unsafe)
20442044
return tensor.squeeze(0) if self.n_envs == 0 else tensor
20452045

2046-
def get_links_acc(self, links_idx=None, envs_idx=None, *, mimick_imu=False, unsafe=False):
2046+
def get_links_acc(self, links_idx=None, envs_idx=None, *, unsafe=False):
20472047
_tensor, links_idx, envs_idx = self._sanitize_2D_io_variables(
20482048
None, links_idx, self.n_links, 3, envs_idx, idx_name="links_idx", unsafe=unsafe
20492049
)
20502050
tensor = _tensor.unsqueeze(0) if self.n_envs == 0 else _tensor
20512051
kernel_get_links_acc(
2052-
mimick_imu,
20532052
tensor,
20542053
links_idx,
20552054
envs_idx,
20562055
self.links_state,
2057-
self._rigid_global_info,
20582056
self._static_rigid_sim_config,
20592057
)
20602058
return _tensor
@@ -6554,12 +6552,10 @@ def kernel_get_links_vel(
65546552

65556553
@ti.kernel
65566554
def kernel_get_links_acc(
6557-
mimick_imu: ti.i32,
65586555
tensor: ti.types.ndarray(),
65596556
links_idx: ti.types.ndarray(),
65606557
envs_idx: ti.types.ndarray(),
65616558
links_state: array_class.LinksState,
6562-
rigid_global_info: array_class.RigidGlobalInfo,
65636559
static_rigid_sim_config: ti.template(),
65646560
):
65656561
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
@@ -6577,14 +6573,6 @@ def kernel_get_links_acc(
65776573
vel = links_state.cd_vel[i_l, i_b] + ang.cross(cpos)
65786574
acc_classic_lin = acc_lin + ang.cross(vel)
65796575

6580-
# Mimick IMU accelerometer signal if requested
6581-
if mimick_imu:
6582-
# Subtract gravity
6583-
acc_classic_lin -= rigid_global_info.gravity[i_b]
6584-
6585-
# Move the resulting linear acceleration in local links frame
6586-
acc_classic_lin = gu.ti_inv_transform_by_quat(acc_classic_lin, links_state.quat[i_l, i_b])
6587-
65886576
for i in ti.static(range(3)):
65896577
tensor[i_b_, i_l_, i] = acc_classic_lin[i]
65906578

genesis/options/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@
22
from .solvers import *
33
from .vis import *
44
from .profiling import ProfilingOptions
5-
from .sensors import SensorOptions
65

76
__all__ = ["ProfilingOptions"]

genesis/options/sensors.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

genesis/sensors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .base_sensor import Sensor
2+
from .imu import IMU
23
from .tactile import RigidContactSensor, RigidContactForceSensor, RigidContactForceGridSensor
34
from .data_recorder import SensorDataRecorder, RecordingOptions
45
from .data_handlers import (

genesis/sensors/base_sensor.py

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,89 @@
1-
from typing import TYPE_CHECKING, Any, List, Type
1+
from dataclasses import dataclass, field
2+
from typing import TYPE_CHECKING, Any, List
23

34
import numpy as np
45
import taichi as ti
56
import torch
67

78
import genesis as gs
9+
from genesis.options import Options
810
from genesis.repr_base import RBC
911

1012
if TYPE_CHECKING:
11-
from genesis.options.sensors import SensorOptions
1213
from genesis.utils.ring_buffer import TensorRingBuffer
1314

1415
from .sensor_manager import SensorManager
1516

1617

18+
class SensorOptions(Options):
19+
"""
20+
Base class for all sensor options.
21+
Each sensor should have their own options class that inherits from this class.
22+
The options class should be registered with the SensorManager using the @register_sensor decorator.
23+
24+
Parameters
25+
----------
26+
read_delay : float
27+
The delay in seconds before the sensor data is read.
28+
"""
29+
30+
read_delay: float = 0.0
31+
32+
def validate(self, scene):
33+
"""
34+
Validate the sensor options values before the sensor is added to the scene.
35+
"""
36+
read_delay_hz = self.read_delay / scene._sim.dt
37+
if not np.isclose(read_delay_hz, round(read_delay_hz), atol=1e-6):
38+
gs.logger.warn(
39+
f"Read delay should be a multiple of the simulation time step. Got {self.read_delay}"
40+
f" and {scene._sim.dt}. Actual read delay will be {1/round(read_delay_hz)}."
41+
)
42+
43+
44+
@dataclass
45+
class SharedSensorMetadata:
46+
"""
47+
Shared metadata between all sensors of the same class.
48+
"""
49+
50+
cache_sizes: list[int] = field(default_factory=list)
51+
read_delay_steps: list[int] = field(default_factory=list)
52+
53+
1754
@ti.data_oriented
1855
class Sensor(RBC):
1956
"""
2057
Base class for all types of sensors.
58+
59+
NOTE: The Sensor system is designed to be performant. All sensors of the same type are updated at once and stored
60+
in a cache in SensorManager. Cache size is inferred from the return format and cache length of each sensor.
61+
`read()` and `read_ground_truth()`, the public-facing methods of every Sensor, automatically handles indexing into
62+
the shared cache to return the correct data.
2163
"""
2264

2365
def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_manager: "SensorManager"):
2466
self._options: "SensorOptions" = sensor_options
2567
self._idx: int = sensor_idx
2668
self._manager: "SensorManager" = sensor_manager
69+
self._shared_metadata: SharedSensorMetadata = sensor_manager._sensors_metadata[type(self)]
70+
71+
self._read_delay_steps = round(self._options.read_delay / self._manager._sim.dt)
72+
self._shared_metadata.read_delay_steps.append(self._read_delay_steps)
2773

28-
# initialized by SensorManager during build
29-
self._read_delay_steps: int = 0
3074
self._shape_indices: list[tuple[int, int]] = []
31-
self._shared_metadata: dict[str, Any] | None = None
32-
self._cache: "TensorRingBuffer" | None = None
75+
return_format = self._get_return_format()
76+
return_shapes = return_format.values() if isinstance(return_format, dict) else (return_format,)
77+
tensor_size = 0
78+
for shape in return_shapes:
79+
data_size = np.prod(shape)
80+
self._shape_indices.append((tensor_size, tensor_size + data_size))
81+
tensor_size += data_size
82+
83+
self._cache_size = self._get_cache_length() * tensor_size
84+
self._shared_metadata.cache_sizes.append(self._cache_size)
85+
86+
self._cache_idx: int = -1 # initialized by SensorManager during build
3387

3488
# =============================== implementable methods ===============================
3589

@@ -71,12 +125,19 @@ def _update_shared_ground_truth_cache(
71125

72126
@classmethod
73127
def _update_shared_cache(
74-
cls, shared_metadata: dict[str, Any], shared_ground_truth_cache: torch.Tensor, shared_cache: "TensorRingBuffer"
128+
cls,
129+
shared_metadata: dict[str, Any],
130+
shared_ground_truth_cache: torch.Tensor,
131+
shared_cache: torch.Tensor,
132+
buffered_data: "TensorRingBuffer",
75133
):
76134
"""
77135
Update the shared sensor cache for all sensors of this class using metadata in SensorManager.
136+
137+
The information in shared_cache should be the final measured sensor data after all noise and post-processing.
138+
NOTE: The implementation should include applying the delay using the `_apply_delay_to_shared_cache()` method.
78139
"""
79-
raise NotImplementedError("Sensors must implement `update_shared_cache()`.")
140+
raise NotImplementedError("Sensors must implement `update_shared_cache_with_noise()`.")
80141

81142
@classmethod
82143
def _get_cache_dtype(cls) -> torch.dtype:
@@ -92,19 +153,35 @@ def read(self, envs_idx: List[int] | None = None):
92153
"""
93154
Read the sensor data (with noise applied if applicable).
94155
"""
95-
return self._get_formatted_data(self._cache.get(self._read_delay_steps), envs_idx)
156+
return self._get_formatted_data(self._manager.get_cloned_from_cache(self), envs_idx)
96157

97158
@gs.assert_built
98159
def read_ground_truth(self, envs_idx: List[int] | None = None):
99160
"""
100161
Read the ground truth sensor data (without noise).
101162
"""
102-
return self._get_formatted_data(self._manager.get_cloned_from_ground_truth_cache(self), envs_idx)
163+
return self._get_formatted_data(self._manager.get_cloned_from_cache(self, is_ground_truth=True), envs_idx)
164+
165+
@classmethod
166+
def _apply_delay_to_shared_cache(
167+
self, shared_metadata: SharedSensorMetadata, shared_cache: torch.Tensor, buffered_data: "TensorRingBuffer"
168+
):
169+
"""
170+
Applies the read delay to the shared cache tensor by copying the buffered data at the appropriate index.
171+
"""
172+
idx = 0
173+
for tensor_size, read_delay_step in zip(shared_metadata.cache_sizes, shared_metadata.read_delay_steps):
174+
shared_cache[:, idx : idx + tensor_size] = buffered_data.at(read_delay_step)[:, idx : idx + tensor_size]
175+
idx += tensor_size
103176

104177
def _get_formatted_data(
105178
self, tensor: torch.Tensor, envs_idx: list[int] | None
106179
) -> torch.Tensor | dict[str, torch.Tensor]:
107-
# Note: This method does not clone the data tensor, it should have been cloned by the caller.
180+
"""
181+
Formats the flattened cache tensor into a dict of tensors using the format specified in `_get_return_format()`.
182+
183+
NOTE: This method does not clone the data tensor, it should have been cloned by the caller.
184+
"""
108185

109186
if envs_idx is None:
110187
envs_idx = self._manager._sim._scene._envs_idx

0 commit comments

Comments
 (0)