@@ -25,6 +25,10 @@ class IMUOptions(SensorOptions):
2525 IMU sensor returns the linear acceleration (accelerometer) and angular velocity (gyroscope)
2626 of the associated entity link.
2727
28+ Note
29+ ----
30+ Accelerometers return the so-called classical linear acceleration in local frame minus gravity.
31+
2832 Parameters
2933 ----------
3034 entity_idx : int
@@ -88,9 +92,10 @@ def build(self):
8892 [self ._shared_metadata .offsets_pos , torch .tensor ([self ._options .pos_offset ], dtype = gs .tc_float )]
8993 )
9094
91- quat_tensor = torch .tensor (euler_to_quat (self ._options .euler_offset ), dtype = gs .tc_float )
92- quat_tensor = quat_tensor .view (1 , 1 , 4 ).expand (self ._manager ._sim ._B , 1 , 4 )
93- self ._shared_metadata .offsets_quat = torch .cat ([self ._shared_metadata .offsets_quat , quat_tensor ], dim = 1 )
95+ quat_tensor = torch .tensor (euler_to_quat (self ._options .euler_offset ), dtype = gs .tc_float ).unsqueeze (0 )
96+ if self ._shared_metadata .solver .n_envs > 0 :
97+ quat_tensor = quat_tensor .unsqueeze (0 ).expand ((self ._manager ._sim ._B , 1 , 4 ))
98+ self ._shared_metadata .offsets_quat = torch .cat ([self ._shared_metadata .offsets_quat , quat_tensor ], dim = - 2 )
9499
95100 self ._shared_metadata .acc_bias = torch .cat (
96101 [self ._shared_metadata .acc_bias , torch .tensor ([self ._options .accelerometer_bias ], dtype = gs .tc_float )]
@@ -119,29 +124,20 @@ def _update_shared_ground_truth_cache(
119124 quats = shared_metadata .solver .get_links_quat (links_idx = shared_metadata .links_idx )
120125 acc = shared_metadata .solver .get_links_acc (links_idx = shared_metadata .links_idx )
121126 ang = shared_metadata .solver .get_links_ang (links_idx = shared_metadata .links_idx )
122- if shared_metadata .solver .n_envs == 0 :
123- gravity = gravity .unsqueeze (0 )
124- quats = quats .unsqueeze (0 )
125- acc = acc .unsqueeze (0 )
126- ang = ang .unsqueeze (0 )
127127
128128 offset_quats = transform_quat_by_quat (quats , shared_metadata .offsets_quat )
129129
130- # acc/ang shape: (B, n_links , 3)
130+ # acc/ang shape: (B, n_imus , 3)
131131 local_acc = inv_transform_by_trans_quat (acc , shared_metadata .offsets_pos , offset_quats )
132132 local_ang = inv_transform_by_trans_quat (ang , shared_metadata .offsets_pos , offset_quats )
133133
134- local_acc = local_acc - gravity .unsqueeze (1 ).expand (- 1 , local_acc .shape [1 ], - 1 )
134+ * batch_size , n_imus , _ = local_acc .shape
135+ local_acc = local_acc - gravity .unsqueeze (- 2 ).expand ((* batch_size , n_imus , - 1 ))
135136
136- # cache shape: (B, n_links * 6)
137- batch_size , n_links , _ = local_acc .shape
138- strided_ground_truth_cache = torch .as_strided (
139- shared_ground_truth_cache ,
140- size = (batch_size , n_links , 2 , 3 ),
141- stride = (n_links * 6 , 6 , 3 , 1 ),
142- )
143- strided_ground_truth_cache [:, :, 0 , :].copy_ (local_acc )
144- strided_ground_truth_cache [:, :, 1 , :].copy_ (local_ang )
137+ # cache shape: (B, n_imus * 6)
138+ strided_ground_truth_cache = shared_ground_truth_cache .reshape ((* batch_size , n_imus , 2 , 3 ))
139+ strided_ground_truth_cache [..., 0 , :].copy_ (local_acc )
140+ strided_ground_truth_cache [..., 1 , :].copy_ (local_ang )
145141
146142 @classmethod
147143 def _update_shared_cache (
@@ -154,20 +150,19 @@ def _update_shared_cache(
154150 """
155151 Update the current measured sensor data for all IMU sensors.
156152
157- NOTE: `buffered_data` contains the history of ground truth cache, and noise/bias is only applied to the current
153+ Note
154+ ----
155+ `buffered_data` contains the history of ground truth cache, and noise/bias is only applied to the current
158156 sensor readout `shared_cache`, not the whole buffer.
159157 """
160158 buffered_data .append (shared_ground_truth_cache )
161159 cls ._apply_delay_to_shared_cache (shared_metadata , shared_cache , buffered_data )
162160
163161 # add bias to the shared_cache
164- strided_shared_cache = torch .as_strided (
165- shared_cache ,
166- size = (shared_cache .shape [0 ], shared_metadata .acc_bias .shape [0 ], 2 , 3 ),
167- stride = (shared_cache .shape [1 ], 6 , 3 , 1 ),
168- )
169- strided_shared_cache [:, :, 0 , :] += shared_metadata .acc_bias
170- strided_shared_cache [:, :, 1 , :] += shared_metadata .ang_bias
162+ * batch_size , n_imus , _ = shared_metadata .offsets_quat .shape
163+ strided_shared_cache = shared_cache .reshape ((* batch_size , n_imus , 2 , 3 ))
164+ strided_shared_cache [..., 0 , :] += shared_metadata .acc_bias
165+ strided_shared_cache [..., 1 , :] += shared_metadata .ang_bias
171166
172167 @classmethod
173168 def _get_cache_dtype (cls ) -> torch .dtype :
0 commit comments