diff --git a/source/isaaclab/isaaclab/sim/prims/xform_prim.py b/source/isaaclab/isaaclab/sim/prims/xform_prim.py index 607f864453f..3eca854b2b9 100644 --- a/source/isaaclab/isaaclab/sim/prims/xform_prim.py +++ b/source/isaaclab/isaaclab/sim/prims/xform_prim.py @@ -228,6 +228,24 @@ def _to_numpy(self, value: torch.Tensor | np.ndarray | Sequence[float] | None) - else: return np.array(value) + @staticmethod + def _world_to_local_tq(prim: Usd.Prim, world_t: Gf.Vec3d, world_q: Gf.Quatd) -> tuple[Gf.Vec3d, Gf.Quatd]: + """Convert desired world (t,q) into local (t,q) wrt prim's parent.""" + time = Usd.TimeCode.Default() + parent = prim.GetParent() + if parent and parent.IsValid(): + parent_w = UsdGeom.Xformable(parent).ComputeLocalToWorldTransform(time) + else: + parent_w = Gf.Matrix4d(1.0) + + world_m = Gf.Matrix4d(1.0) + world_m.SetRotate(world_q) + world_m.SetTranslateOnly(world_t) + + local_m = parent_w.GetInverse() * world_m + local_tf = Gf.Transform(local_m) + return local_tf.GetTranslation(), local_tf.GetRotation().GetQuat() + def set_world_poses( self, positions: torch.Tensor | np.ndarray | Sequence[float] | None = None, @@ -282,19 +300,26 @@ def set_world_poses( orient_op = xformable.AddXformOp(UsdGeom.XformOp.TypeOrient, UsdGeom.XformOp.PrecisionDouble) # Set position - if pos_np is not None: - # Convert numpy values to Python floats for USD - translate_op.Set(Gf.Vec3d(float(pos_np[idx, 0]), float(pos_np[idx, 1]), float(pos_np[idx, 2]))) - - # Set orientation - if orient_np is not None: - # Convert numpy values to Python floats for USD - w = float(orient_np[idx, 0]) - x = float(orient_np[idx, 1]) - y = float(orient_np[idx, 2]) - z = float(orient_np[idx, 3]) - quat = Gf.Quatd(w, Gf.Vec3d(x, y, z)) - orient_op.Set(quat) + if pos_np is not None or orient_np is not None: + current_world = xformable.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) + + if pos_np is not None: + world_t = Gf.Vec3d(float(pos_np[idx, 0]), float(pos_np[idx, 1]), float(pos_np[idx, 2])) + else: + world_t = current_world.ExtractTranslation() + + if orient_np is not None: + w = float(orient_np[idx, 0]) + x = float(orient_np[idx, 1]) + y = float(orient_np[idx, 2]) + z = float(orient_np[idx, 3]) + world_q = Gf.Quatd(w, Gf.Vec3d(x, y, z)) + else: + world_q = current_world.ExtractRotation().GetQuat() + + local_t, local_q = self._world_to_local_tq(prim, world_t, world_q) + translate_op.Set(local_t) + orient_op.Set(local_q) def set_local_poses( self, @@ -310,7 +335,59 @@ def set_local_poses( indices: Indices of prims to update. If None, all prims are updated. """ # For local poses, we use the same method since USD xform ops are inherently local - self.set_world_poses(positions=translations, orientations=orientations, indices=indices) + # Convert to numpy + trans_np = self._to_numpy(translations) + orient_np = self._to_numpy(orientations) + indices_np = self._to_numpy(indices) + + # Determine which prims to update + if indices_np is None: + prim_indices = range(self._count) + else: + prim_indices = indices_np.astype(int) + + # Broadcast if needed + if trans_np is not None: + if trans_np.ndim == 1: + trans_np = np.tile(trans_np, (len(prim_indices), 1)) + + if orient_np is not None: + if orient_np.ndim == 1: + orient_np = np.tile(orient_np, (len(prim_indices), 1)) + + # Update each prim + for idx, prim_idx in enumerate(prim_indices): + prim = self._prims[prim_idx] + xformable = UsdGeom.Xformable(prim) + + # Get or create the translate op + translate_attr = prim.GetAttribute("xformOp:translate") + if translate_attr: + translate_op = UsdGeom.XformOp(translate_attr) + else: + translate_op = xformable.AddXformOp(UsdGeom.XformOp.TypeTranslate, UsdGeom.XformOp.PrecisionDouble) + + # Get or create the orient op + orient_attr = prim.GetAttribute("xformOp:orient") + if orient_attr: + orient_op = UsdGeom.XformOp(orient_attr) + else: + orient_op = xformable.AddXformOp(UsdGeom.XformOp.TypeOrient, UsdGeom.XformOp.PrecisionDouble) + + # Set translation + if trans_np is not None: + # Convert numpy values to Python floats for USD + translate_op.Set(Gf.Vec3d(float(trans_np[idx, 0]), float(trans_np[idx, 1]), float(trans_np[idx, 2]))) + + # Set orientation + if orient_np is not None: + # Convert numpy values to Python floats for USD + w = float(orient_np[idx, 0]) + x = float(orient_np[idx, 1]) + y = float(orient_np[idx, 2]) + z = float(orient_np[idx, 3]) + quat = Gf.Quatd(w, Gf.Vec3d(x, y, z)) + orient_op.Set(quat) def set_local_scales( self,