Skip to content

Commit 42be9cb

Browse files
committed
Simplify batching support.
1 parent 843656d commit 42be9cb

File tree

2 files changed

+23
-27
lines changed

2 files changed

+23
-27
lines changed

genesis/sensors/imu.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class IMUOptions(SensorOptions):
2525
IMU sensor returns the linear acceleration (accelerometer) and angular velocity (gyroscope)
2626
of the associated entity link.
2727
28+
Note
29+
----
30+
Accelerometers return the so-called classical linear acceleration in local frame minus gravity.
31+
2832
Parameters
2933
----------
3034
entity_idx : int
@@ -88,9 +92,10 @@ def build(self):
8892
[self._shared_metadata.offsets_pos, torch.tensor([self._options.pos_offset], dtype=gs.tc_float)]
8993
)
9094

91-
quat_tensor = torch.tensor(euler_to_quat(self._options.euler_offset), dtype=gs.tc_float)
92-
quat_tensor = quat_tensor.view(1, 1, 4).expand(self._manager._sim._B, 1, 4)
93-
self._shared_metadata.offsets_quat = torch.cat([self._shared_metadata.offsets_quat, quat_tensor], dim=1)
95+
quat_tensor = torch.tensor(euler_to_quat(self._options.euler_offset), dtype=gs.tc_float).unsqueeze(0)
96+
if self._shared_metadata.solver.n_envs > 0:
97+
quat_tensor = quat_tensor.unsqueeze(0).expand((self._manager._sim._B, 1, 4))
98+
self._shared_metadata.offsets_quat = torch.cat([self._shared_metadata.offsets_quat, quat_tensor], dim=-2)
9499

95100
self._shared_metadata.acc_bias = torch.cat(
96101
[self._shared_metadata.acc_bias, torch.tensor([self._options.accelerometer_bias], dtype=gs.tc_float)]
@@ -119,29 +124,20 @@ def _update_shared_ground_truth_cache(
119124
quats = shared_metadata.solver.get_links_quat(links_idx=shared_metadata.links_idx)
120125
acc = shared_metadata.solver.get_links_acc(links_idx=shared_metadata.links_idx)
121126
ang = shared_metadata.solver.get_links_ang(links_idx=shared_metadata.links_idx)
122-
if shared_metadata.solver.n_envs == 0:
123-
gravity = gravity.unsqueeze(0)
124-
quats = quats.unsqueeze(0)
125-
acc = acc.unsqueeze(0)
126-
ang = ang.unsqueeze(0)
127127

128128
offset_quats = transform_quat_by_quat(quats, shared_metadata.offsets_quat)
129129

130-
# acc/ang shape: (B, n_links, 3)
130+
# acc/ang shape: (B, n_imus, 3)
131131
local_acc = inv_transform_by_trans_quat(acc, shared_metadata.offsets_pos, offset_quats)
132132
local_ang = inv_transform_by_trans_quat(ang, shared_metadata.offsets_pos, offset_quats)
133133

134-
local_acc = local_acc - gravity.unsqueeze(1).expand(-1, local_acc.shape[1], -1)
134+
*batch_size, n_imus, _ = local_acc.shape
135+
local_acc = local_acc - gravity.unsqueeze(-2).expand((*batch_size, n_imus, -1))
135136

136-
# cache shape: (B, n_links * 6)
137-
batch_size, n_links, _ = local_acc.shape
138-
strided_ground_truth_cache = torch.as_strided(
139-
shared_ground_truth_cache,
140-
size=(batch_size, n_links, 2, 3),
141-
stride=(n_links * 6, 6, 3, 1),
142-
)
143-
strided_ground_truth_cache[:, :, 0, :].copy_(local_acc)
144-
strided_ground_truth_cache[:, :, 1, :].copy_(local_ang)
137+
# cache shape: (B, n_imus * 6)
138+
strided_ground_truth_cache = shared_ground_truth_cache.reshape((*batch_size, n_imus, 2, 3))
139+
strided_ground_truth_cache[..., 0, :].copy_(local_acc)
140+
strided_ground_truth_cache[..., 1, :].copy_(local_ang)
145141

146142
@classmethod
147143
def _update_shared_cache(
@@ -154,20 +150,19 @@ def _update_shared_cache(
154150
"""
155151
Update the current measured sensor data for all IMU sensors.
156152
157-
NOTE: `buffered_data` contains the history of ground truth cache, and noise/bias is only applied to the current
153+
Note
154+
----
155+
`buffered_data` contains the history of ground truth cache, and noise/bias is only applied to the current
158156
sensor readout `shared_cache`, not the whole buffer.
159157
"""
160158
buffered_data.append(shared_ground_truth_cache)
161159
cls._apply_delay_to_shared_cache(shared_metadata, shared_cache, buffered_data)
162160

163161
# add bias to the shared_cache
164-
strided_shared_cache = torch.as_strided(
165-
shared_cache,
166-
size=(shared_cache.shape[0], shared_metadata.acc_bias.shape[0], 2, 3),
167-
stride=(shared_cache.shape[1], 6, 3, 1),
168-
)
169-
strided_shared_cache[:, :, 0, :] += shared_metadata.acc_bias
170-
strided_shared_cache[:, :, 1, :] += shared_metadata.ang_bias
162+
*batch_size, n_imus, _ = shared_metadata.offsets_quat.shape
163+
strided_shared_cache = shared_cache.reshape((*batch_size, n_imus, 2, 3))
164+
strided_shared_cache[..., 0, :] += shared_metadata.acc_bias
165+
strided_shared_cache[..., 1, :] += shared_metadata.ang_bias
171166

172167
@classmethod
173168
def _get_cache_dtype(cls) -> torch.dtype:

tests/test_sensors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_imu_sensor(show_viewer):
2020
gravity=(0.0, 0.0, GRAVITY),
2121
),
2222
show_viewer=show_viewer,
23+
show_FPS=False,
2324
)
2425

2526
scene.add_entity(gs.morphs.Plane())

0 commit comments

Comments
 (0)