diff --git a/source/isaaclab/isaaclab/utils/backend_utils.py b/source/isaaclab/isaaclab/utils/backend_utils.py index 0ce11c09205..aacdaadb0c4 100644 --- a/source/isaaclab/isaaclab/utils/backend_utils.py +++ b/source/isaaclab/isaaclab/utils/backend_utils.py @@ -36,6 +36,7 @@ def register(cls, name: str, sub_class) -> None: def __new__(cls, *args, **kwargs): """Create a new instance of an implementation based on the backend.""" + # TODO: Make the backend configurable. backend = "newton" if cls == FactoryBase: diff --git a/source/isaaclab/isaaclab/utils/buffers/timestamped_wp_buffer.py b/source/isaaclab/isaaclab/utils/buffers/timestamped_wp_buffer.py index db9eca0fd05..42bf0824247 100644 --- a/source/isaaclab/isaaclab/utils/buffers/timestamped_wp_buffer.py +++ b/source/isaaclab/isaaclab/utils/buffers/timestamped_wp_buffer.py @@ -34,10 +34,13 @@ class TimestampedWarpBuffer: dtype: type | None = None """Dtype of the data stored in the buffer. Default is None, indicating that the buffer is empty.""" + device: str = "cuda:0" + """Device of the data stored in the buffer. Default is "cuda:0", indicating that the buffer is on the first GPU.""" + def __post_init__(self): if self.shape is None: raise ValueError("Shape of the data stored in the buffer is not set.") if self.dtype is None: raise ValueError("Dtype of the data stored in the buffer is not set.") if self.data is None: - self.data = wp.empty(self.shape, dtype=self.dtype) + self.data = wp.empty(self.shape, dtype=self.dtype, device=self.device) diff --git a/source/isaaclab/isaaclab/utils/warp/utils.py b/source/isaaclab/isaaclab/utils/warp/utils.py new file mode 100644 index 00000000000..aa12f8a522d --- /dev/null +++ b/source/isaaclab/isaaclab/utils/warp/utils.py @@ -0,0 +1,123 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import logging +import torch +from collections.abc import Sequence + +import warp as wp + +logger = logging.getLogger(__name__) + +## +# Frontend conversions - Torch to Warp. +## + + +# TODO: Perf is atrocious. Need to improve. +# Option 1: Pre-allocate the complete data buffer and fill it with the value. +# Option 2: Create a torch pointer to the warp array and by pass these methods using torch indexing to update +# the warp array. This would save the memory allocation and the generation of the masks. +def make_complete_data_from_torch_single_index( + value: torch.Tensor, + N: int, + ids: Sequence[int] | torch.Tensor | None = None, + dtype: type = wp.float32, + device: str = "cuda:0", +) -> wp.array: + """Converts any Torch frontend data into warp data with single index support. + + Args: + value: The value to convert. Shape is (N,). + N: The number of elements in the value. + ids: The index ids. + dtype: The dtype of the value. + device: The device to use for the conversion. + + Returns: + A warp array. + """ + if ids is None: + # No ids are provided, so we are expecting complete data. + value = wp.from_torch(value, dtype=dtype) + else: + # Create a complete data buffer from scratch + complete = torch.zeros((N, *value.shape[1:]), dtype=torch.float32, device=device) + complete[ids] = value + value = wp.from_torch(complete, dtype=dtype) + return value + + +def make_complete_data_from_torch_dual_index( + value: torch.Tensor, + N: int, + M: int, + first_ids: Sequence[int] | torch.Tensor | None = None, + second_ids: Sequence[int] | torch.Tensor | None = None, + dtype: type = wp.float32, + device: str = "cuda:0", +) -> wp.array: + """Converts any Torch frontend data into warp data with dual index support. + + Args: + value: The value to convert. Shape is (N, M) or (len(first_ids), len(second_ids)). + N: The number of elements in the first dimension. + M: The number of elements in the second dimension. + first_ids: The first index ids. + second_ids: The second index ids. + dtype: The dtype of the value. + device: The device to use for the conversion. + + Returns: + A tuple of warp data with its two masks. + """ + if (first_ids is None) and (second_ids is None): + # No ids are provided, so we are expecting complete data. + value = wp.from_torch(value, dtype=dtype) + else: + # Create a complete data buffer from scratch + complete = torch.zeros((N, M, *value.shape[2:]), dtype=torch.float32, device=device) + # Fill the complete data buffer with the value. + if first_ids is None: + first_ids = slice(None) + if second_ids is None: + second_ids = slice(None) + if first_ids != slice(None) and second_ids != slice(None): + if isinstance(first_ids, list): + first_ids = torch.tensor(first_ids, dtype=torch.int32, device=device) + first_ids = first_ids[:, None] + complete[first_ids, second_ids] = value + value = wp.from_torch(complete, dtype=dtype) + return value + + +def make_masks_from_torch_ids( + N: int, + first_ids: Sequence[int] | torch.Tensor | None = None, + first_mask: wp.array | torch.Tensor | None = None, + device: str = "cuda:0", +) -> wp.array | None: + """Converts any Torch frontend data into warp data with dual index support. + + Args: + value: The value to convert. Shape is (N, M) or (len(first_ids), len(second_ids)). + first_ids: The first index ids. + second_ids: The second index ids. + first_mask: The first index mask. + second_mask: The second index mask. + dtype: The dtype of the value. + device: The device to use for the conversion. + + Returns: + A tuple of warp data with its two masks. + """ + if (first_ids is not None) and (first_mask is None): + # Create a mask from scratch + first_mask = torch.zeros(N, dtype=torch.bool, device=device) + first_mask[first_ids] = True + first_mask = wp.from_torch(first_mask, dtype=wp.bool) + elif isinstance(first_mask, torch.Tensor): + first_mask = wp.from_torch(first_mask, dtype=wp.bool) + return first_mask diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py index 1ba4fa877d6..e11cc29b267 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation.py @@ -18,13 +18,11 @@ import warp as wp from isaaclab_newton.actuators import ActuatorBase, ImplicitActuator from isaaclab_newton.assets.articulation.articulation_data import ArticulationData +from isaaclab_newton.assets.utils.shared import find_bodies, find_joints from isaaclab_newton.kernels import ( - generate_mask_from_ids, - split_state_to_pose, - split_state_to_velocity, + project_link_velocity_to_com_frame_masked_root, + split_state_to_pose_and_velocity, transform_CoM_pose_to_link_frame_masked_root, - update_joint_limits, - update_joint_limits_value_vec2f, update_soft_joint_pos_limits, update_wrench_array_with_force, update_wrench_array_with_torque, @@ -39,7 +37,7 @@ import isaaclab.utils.string as string_utils from isaaclab.assets.articulation.base_articulation import BaseArticulation from isaaclab.sim._impl.newton_manager import NewtonManager -from isaaclab.utils.helpers import deprecated, warn_overhead_cost +from isaaclab.utils.helpers import deprecated from isaaclab.utils.warp.update_kernels import ( update_array1D_with_array1D_masked, update_array1D_with_value, @@ -48,6 +46,11 @@ update_array2D_with_array2D_masked, update_array2D_with_value_masked, ) +from isaaclab.utils.warp.utils import ( + make_complete_data_from_torch_dual_index, + make_complete_data_from_torch_single_index, + make_masks_from_torch_ids, +) if TYPE_CHECKING: from isaaclab.actuators.actuator_cfg import ActuatorBaseCfg @@ -165,6 +168,7 @@ def num_bodies(self) -> int: """Number of bodies in articulation.""" return self._root_view.link_count + @property def num_shapes_per_body(self) -> list[int]: """Number of collision shapes per body in the articulation. @@ -177,7 +181,7 @@ def num_shapes_per_body(self) -> list[int]: """ if not hasattr(self, "_num_shapes_per_body"): self._num_shapes_per_body = [] - for shapes in self._root_newton_view.body_shapes: + for shapes in self._root_view.body_shapes: self._num_shapes_per_body.append(len(shapes)) return self._num_shapes_per_body @@ -217,6 +221,8 @@ def root_newton_model(self) -> Model: """Newton model for the asset.""" return self._root_view.model + # TODO: Plug-in the Wrench code from Isaac Lab once the PR gets in. + """ Operations. """ @@ -230,7 +236,7 @@ def reset(self, ids: Sequence[int] | None = None, mask: wp.array | None = None): if isinstance(mask, torch.Tensor): mask = wp.from_torch(mask, dtype=wp.bool) else: - mask = self._data.ALL_ENV_MASK + mask = self._data.ALL_BODY_MASK # reset external wrench wp.launch( update_array2D_with_value_masked, @@ -238,8 +244,8 @@ def reset(self, ids: Sequence[int] | None = None, mask: wp.array | None = None): inputs=[ wp.spatial_vectorf(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), self._data._sim_bind_body_external_wrench, - mask, self._data.ALL_ENV_MASK, + mask, ], ) @@ -269,129 +275,6 @@ def write_data_to_sim(self): def update(self, dt: float): self._data.update(dt) - """ - Frontend conversions - Torch to Warp. - """ - - # FIXME: Move this to utils / helpers. - def _torch_to_warp_single_index( - self, - value: torch.Tensor, - N: int, - ids: Sequence[int] | None = None, - mask: torch.Tensor | None = None, - dtype: type = wp.float32, - ) -> tuple[wp.array, wp.array | None]: - """Converts any Torch frontend data into warp data with single index support. - - Args: - value: The value to convert. Shape is (N,). - N: The number of elements in the value. - ids: The index ids. - mask: The index mask. - dtype: The dtype of the value. - - Returns: - A tuple of warp data with its mask. - """ - if mask is None: - if ids is not None: - # Create a mask from scratch - env_mask = torch.zeros(N, dtype=torch.bool, device=self.device) - env_mask[ids] = True - env_mask = wp.from_torch(env_mask, dtype=wp.bool) - # Create a complete data buffer from scratch - complete = torch.zeros((N, *value.shape[1:]), dtype=value.dtype, device=self.device) - complete[ids] = value - value = wp.from_torch(complete, dtype=dtype) - else: - value = wp.from_torch(value, dtype=dtype) - else: - if ids is not None: - warnings.warn( - "ids is not None, but mask is provided. Ignoring ids. Please make sure you are providing complete" - " data buffers.", - UserWarning, - ) - env_mask = wp.from_torch(mask, dtype=wp.bool) - value = wp.from_torch(value, dtype=dtype) - return value, env_mask - - # FIXME: Move this to utils / helpers. - def _torch_to_warp_dual_index( - self, - value: torch.Tensor, - N: int, - M: int, - first_ids: Sequence[int] | None = None, - second_ids: Sequence[int] | None = None, - first_mask: torch.Tensor | None = None, - second_mask: torch.Tensor | None = None, - dtype: type = wp.float32, - ) -> tuple[wp.array, wp.array | None, wp.array | None]: - """Converts any Torch frontend data into warp data with dual index support. - - Args: - value: The value to convert. Shape is (N, M) or (len(first_ids), len(second_ids)). - first_ids: The first index ids. - second_ids: The second index ids. - first_mask: The first index mask. - second_mask: The second index mask. - dtype: The dtype of the value. - - Returns: - A tuple of warp data with its two masks. - """ - if first_mask is None: - if (first_ids is not None) or (second_ids is not None): - # If we are provided with either first_ids or second_ids, we need to create a mask from scratch and - # we are expecting partial values. - if first_ids is not None: - # Create a mask from scratch - first_mask = torch.zeros(N, dtype=torch.bool, device=self.device) - first_mask[first_ids] = True - first_mask = wp.from_torch(first_mask, dtype=wp.bool) - else: - first_ids = slice(None) - if second_ids is not None: - # Create a mask from scratch - second_mask = torch.zeros(M, dtype=torch.bool, device=self.device) - second_mask[second_ids] = True - second_mask = wp.from_torch(second_mask, dtype=wp.bool) - else: - second_ids = slice(None) - if first_ids != slice(None) and second_ids != slice(None): - first_ids = first_ids[:, None] - - # Create a complete data buffer from scratch - if dtype == wp.vec3f: - complete_value = torch.zeros(N, M, 3, dtype=value.dtype, device=self.device) - elif dtype == wp.mat33f: - complete_value = torch.zeros(N, M, 3, 3, dtype=value.dtype, device=self.device) - else: - complete_value = torch.zeros(N, M, dtype=value.dtype, device=self.device) - complete_value[first_ids, second_ids] = value - value = wp.from_torch(complete_value, dtype=dtype) - elif second_mask is not None: - second_mask = wp.from_torch(second_mask, dtype=wp.bool) - value = wp.from_torch(value, dtype=dtype) - else: - value = wp.from_torch(value, dtype=dtype) - else: - if (first_ids is not None) or (second_ids is not None): - warnings.warn( - "Mask and ids are provided. Ignoring ids. Please make sure you are providing complete data" - " buffers.", - UserWarning, - ) - first_mask = wp.from_torch(first_mask, dtype=wp.bool) - if second_mask is not None: - second_mask = wp.from_torch(second_mask, dtype=wp.bool) - else: - value = wp.from_torch(value, dtype=dtype) - - return value, first_mask, second_mask - """ Operations - Finders. """ @@ -411,7 +294,7 @@ def find_bodies( Returns: A tuple of lists containing the body mask, names, and indices. """ - return self._find_bodies(name_keys, preserve_order) + return find_bodies(self.body_names, name_keys, preserve_order, self.device) def find_joints( self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False @@ -430,7 +313,7 @@ def find_joints( Returns: A tuple of lists containing the joint mask, names, and indices. """ - return self._find_joints(name_keys, joint_subset, preserve_order) + return find_joints(self.joint_names, name_keys, joint_subset, preserve_order, self.device) def find_fixed_tendons( self, name_keys: str | Sequence[str], tendon_subsets: list[str] | None = None, preserve_order: bool = False @@ -494,9 +377,10 @@ def write_root_state_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_state, torch.Tensor): - root_state, env_mask = self._torch_to_warp_single_index( - root_state, self.num_instances, env_ids, env_mask, dtype=vec13f + root_state = make_complete_data_from_torch_single_index( + root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) # solve for None masks if env_mask is None: env_mask = self._data.ALL_ENV_MASK @@ -525,9 +409,10 @@ def write_root_com_state_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_state, torch.Tensor): - root_state, env_mask = self._torch_to_warp_single_index( - root_state, self.num_instances, env_ids, env_mask, dtype=vec13f + root_state = make_complete_data_from_torch_single_index( + root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) if env_mask is None: env_mask = self._data.ALL_ENV_MASK # split the state into pose and velocity @@ -555,9 +440,10 @@ def write_root_link_state_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_state, torch.Tensor): - root_state, env_mask = self._torch_to_warp_single_index( - root_state, self.num_instances, env_ids, env_mask, dtype=vec13f + root_state = make_complete_data_from_torch_single_index( + root_state, self.num_instances, ids=env_ids, dtype=vec13f, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) if env_mask is None: env_mask = self._data.ALL_ENV_MASK # split the state into pose and velocity @@ -601,9 +487,10 @@ def write_root_link_pose_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(pose, torch.Tensor): - pose, env_mask = self._torch_to_warp_single_index( - pose, self.num_instances, env_ids, env_mask, dtype=wp.transformf + pose = make_complete_data_from_torch_single_index( + pose, self.num_instances, ids=env_ids, dtype=wp.transformf, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) # solve for None masks if env_mask is None: env_mask = self._data.ALL_ENV_MASK @@ -630,25 +517,29 @@ def write_root_com_pose_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_pose, torch.Tensor): - root_pose, env_mask = self._torch_to_warp_single_index( - root_pose, self.num_instances, env_ids, env_mask, dtype=wp.transformf + root_pose = make_complete_data_from_torch_single_index( + root_pose, self.num_instances, ids=env_ids, dtype=wp.transformf, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) # solve for None masks if env_mask is None: env_mask = self._data.ALL_ENV_MASK # Write to Newton using warp - self._update_array_with_array_masked(root_pose, self._data.root_com_pose_w.data, env_mask, self.num_instances) + self._update_array_with_array_masked(root_pose, self._data._root_com_pose_w.data, env_mask, self.num_instances) # set link frame poses wp.launch( transform_CoM_pose_to_link_frame_masked_root, dim=self.num_instances, + device=self.device, inputs=[ - self._data.root_com_pose_w, + self._data._root_com_pose_w.data, self._data.body_com_pos_b, self._data.root_link_pose_w, env_mask, ], ) + # Force update the timestamp + self._data._root_com_pose_w.timestamp = self._data._sim_timestamp def write_root_velocity_to_sim( self, @@ -686,9 +577,10 @@ def write_root_com_velocity_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_velocity, torch.Tensor): - root_velocity, env_mask = self._torch_to_warp_single_index( - root_velocity, self.num_instances, env_ids, env_mask, dtype=wp.spatial_vectorf + root_velocity = make_complete_data_from_torch_single_index( + root_velocity, self.num_instances, ids=env_ids, dtype=wp.spatial_vectorf, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) # solve for None masks if env_mask is None: env_mask = self._data.ALL_ENV_MASK @@ -714,14 +606,32 @@ def write_root_link_velocity_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(root_velocity, torch.Tensor): - root_velocity, env_mask = self._torch_to_warp_single_index( - root_velocity, self.num_instances, env_ids, env_mask, dtype=wp.spatial_vectorf + root_velocity = make_complete_data_from_torch_single_index( + root_velocity, self.num_instances, ids=env_ids, dtype=wp.spatial_vectorf, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) # solve for None masks if env_mask is None: env_mask = self._data.ALL_ENV_MASK + # update the root link velocity + self._update_array_with_array_masked( + root_velocity, self._data._root_link_vel_w.data, env_mask, self.num_instances + ) # set into simulation - self._write_root_link_velocity_to_sim(root_velocity, env_mask) + wp.launch( + project_link_velocity_to_com_frame_masked_root, + dim=self.num_instances, + device=self.device, + inputs=[ + root_velocity, + self._data.root_link_pose_w, + self._data.body_com_pos_b, + self._data.root_com_vel_w, + env_mask, + ], + ) + # Force update the timestamp + self._data._root_link_vel_w.timestamp = self._data._sim_timestamp # invalidate the derived velocities self._data._root_link_vel_b.timestamp = -1.0 self._data._root_com_vel_b.timestamp = -1.0 @@ -747,31 +657,27 @@ def write_joint_state_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(position, torch.Tensor): - position, _, _ = self._torch_to_warp_dual_index( + position = make_complete_data_from_torch_dual_index( position, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - velocity, env_mask, joint_mask = self._torch_to_warp_dual_index( + velocity = make_complete_data_from_torch_dual_index( velocity, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation self._update_batched_array_with_batched_array_masked( position, self._data.joint_pos, env_mask, joint_mask, (self.num_instances, self.num_joints) @@ -799,21 +705,18 @@ def write_joint_position_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(position, torch.Tensor): - position, env_mask, joint_mask = self._torch_to_warp_dual_index( + position = make_complete_data_from_torch_dual_index( position, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation self._update_batched_array_with_batched_array_masked( position, self._data.joint_pos, env_mask, joint_mask, (self.num_instances, self.num_joints) @@ -838,21 +741,18 @@ def write_joint_velocity_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(velocity, torch.Tensor): - velocity, env_mask, joint_mask = self._torch_to_warp_dual_index( + velocity = make_complete_data_from_torch_dual_index( velocity, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation self._update_batched_array_with_batched_array_masked( velocity, self._data.joint_vel, env_mask, joint_mask, (self.num_instances, self.num_joints) @@ -881,21 +781,18 @@ def write_joint_stiffness_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(stiffness, torch.Tensor): - stiffness, env_mask, joint_mask = self._torch_to_warp_dual_index( + stiffness = make_complete_data_from_torch_dual_index( stiffness, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(stiffness, float): self._update_batched_array_with_value_masked( @@ -927,14 +824,18 @@ def write_joint_damping_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(damping, torch.Tensor): - damping, env_mask, joint_mask = self._torch_to_warp_dual_index( - damping, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + damping = make_complete_data_from_torch_dual_index( + damping, + self.num_instances, + self.num_joints, + env_ids, + joint_ids, + dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(damping, float): self._update_batched_array_with_value_masked( @@ -949,8 +850,8 @@ def write_joint_damping_to_sim( def write_joint_position_limit_to_sim( self, - upper_limits: wp.array | float, lower_limits: wp.array | float, + upper_limits: wp.array | float, joint_ids: Sequence[int] | None = None, env_ids: Sequence[int] | None = None, joint_mask: wp.array | None = None, @@ -959,8 +860,8 @@ def write_joint_position_limit_to_sim( """Write joint position limits into the simulation. Args: - upper_limits: Joint upper limits. Shape is (len(env_ids), len(joint_ids)) or (num_instances, num_joints). lower_limits: Joint lower limits. Shape is (num_instances, num_joints). + upper_limits: Joint upper limits. Shape is (len(env_ids), len(joint_ids)) or (num_instances, num_joints). joint_ids: The joint indices to set the targets for. Defaults to None (all joints). env_ids: The environment indices to set the targets for. Defaults to None (all environments). joint_mask: The joint mask. Shape is (num_joints). @@ -968,33 +869,30 @@ def write_joint_position_limit_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(upper_limits, torch.Tensor): - upper_limits, _, _ = self._torch_to_warp_dual_index( + upper_limits = make_complete_data_from_torch_dual_index( upper_limits, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - lower_limits, env_mask, joint_mask = self._torch_to_warp_dual_index( + if isinstance(lower_limits, torch.Tensor): + lower_limits = make_complete_data_from_torch_dual_index( lower_limits, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation - self._write_joint_position_limit_to_sim(upper_limits, lower_limits, joint_mask, env_mask) + self._write_joint_position_limit_to_sim(lower_limits, upper_limits, joint_mask, env_mask) # tell the physics engine that some of the joint properties have been updated NewtonManager.add_model_change(SolverNotifyFlags.JOINT_DOF_PROPERTIES) @@ -1026,14 +924,18 @@ def write_joint_velocity_limit_to_sim( warnings.warn("write_joint_velocity_limit_to_sim is ignored by the solver when using Mujoco.") # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(limits, torch.Tensor): - limits, env_mask, joint_mask = self._torch_to_warp_dual_index( - limits, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + limits = make_complete_data_from_torch_dual_index( + limits, + self.num_instances, + self.num_joints, + env_ids, + joint_ids, + dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(limits, float): self._update_batched_array_with_value_masked( @@ -1068,14 +970,18 @@ def write_joint_effort_limit_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(limits, torch.Tensor): - limits, env_mask, joint_mask = self._torch_to_warp_dual_index( - limits, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + limits = make_complete_data_from_torch_dual_index( + limits, + self.num_instances, + self.num_joints, + env_ids, + joint_ids, + dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(limits, float): self._update_batched_array_with_value_masked( @@ -1110,21 +1016,18 @@ def write_joint_armature_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(armature, torch.Tensor): - armature, env_mask, joint_mask = self._torch_to_warp_dual_index( + armature = make_complete_data_from_torch_dual_index( armature, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(armature, float): self._update_batched_array_with_value_masked( @@ -1165,21 +1068,18 @@ def write_joint_friction_coefficient_to_sim( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(joint_friction_coeff, torch.Tensor): - joint_friction_coeff, env_mask, joint_mask = self._torch_to_warp_dual_index( + joint_friction_coeff = make_complete_data_from_torch_dual_index( joint_friction_coeff, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(joint_friction_coeff, float): self._update_batched_array_with_value_masked( @@ -1210,21 +1110,18 @@ def write_joint_dynamic_friction_coefficient_to_sim( ) -> None: # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(joint_dynamic_friction_coeff, torch.Tensor): - joint_dynamic_friction_coeff, env_mask, joint_mask = self._torch_to_warp_dual_index( + joint_dynamic_friction_coeff = make_complete_data_from_torch_dual_index( joint_dynamic_friction_coeff, self.num_instances, self.num_joints, env_ids, joint_ids, - env_mask, - joint_mask, dtype=wp.float32, + device=self.device, ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) + # None masks are handled within the kernel. # set into simulation if isinstance(joint_dynamic_friction_coeff, float): self._update_batched_array_with_value_masked( @@ -1319,14 +1216,12 @@ def set_masses( """ # raise NotImplementedError() if isinstance(masses, torch.Tensor): - masses, env_mask, body_mask = self._torch_to_warp_dual_index( - masses, self.num_instances, self.num_bodies, env_ids, body_ids, env_mask, body_mask, dtype=wp.float32 + masses = make_complete_data_from_torch_dual_index( + masses, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.float32, device=self.device ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if body_mask is None: - body_mask = self._data.ALL_BODY_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + # None masks are handled within the kernel. self._update_batched_array_with_batched_array_masked( masses, self._data.body_mass, env_mask, body_mask, (self.num_instances, self.num_bodies) ) @@ -1349,16 +1244,13 @@ def set_coms( body_mask: The body mask. Shape is (num_bodies). env_mask: The environment mask. Shape is (num_instances,). """ - # raise NotImplementedError() if isinstance(coms, torch.Tensor): - coms, env_mask, body_mask = self._torch_to_warp_dual_index( - coms, self.num_instances, self.num_bodies, env_ids, body_ids, env_mask, body_mask, dtype=wp.vec3f + coms = make_complete_data_from_torch_dual_index( + coms, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.vec3f, device=self.device ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if body_mask is None: - body_mask = self._data.ALL_BODY_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + # None masks are handled within the kernel. self._update_batched_array_with_batched_array_masked( coms, self._data.body_com_pos_b, env_mask, body_mask, (self.num_instances, self.num_bodies) ) @@ -1382,19 +1274,18 @@ def set_inertias( env_mask: The environment mask. Shape is (num_instances,). """ if isinstance(inertias, torch.Tensor): - inertias, env_mask, body_mask = self._torch_to_warp_dual_index( - inertias, self.num_instances, self.num_bodies, env_ids, body_ids, env_mask, body_mask, dtype=wp.mat33f + inertias = make_complete_data_from_torch_dual_index( + inertias, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.mat33f, device=self.device ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if body_mask is None: - body_mask = self._data.ALL_BODY_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) + # None masks are handled within the kernel. self._update_batched_array_with_batched_array_masked( inertias, self._data.body_inertia, env_mask, body_mask, (self.num_instances, self.num_bodies) ) NewtonManager.add_model_change(SolverNotifyFlags.BODY_PROPERTIES) + # TODO: Plug-in the Wrench code from Isaac Lab once the PR gets in. def set_external_force_and_torque( self, forces: torch.Tensor | wp.array, @@ -1403,6 +1294,8 @@ def set_external_force_and_torque( env_ids: Sequence[int] | None = None, body_mask: wp.array | None = None, env_mask: wp.array | None = None, + positions: torch.Tensor | wp.array | None = None, + is_global: bool = False, ) -> None: """Set external force and torque to apply on the asset's bodies in their local frame. @@ -1431,19 +1324,25 @@ def set_external_force_and_torque( env_ids: The environment indices to set the targets for. Defaults to None (all environments). body_mask: The body mask. Shape is (num_bodies). env_mask: The environment mask. Shape is (num_instances,). + positions: External wrench positions in bodies' local frame. Shape is (len(env_ids), len(body_ids), 3). + Defaults to None. If None, the external wrench is applied at the center of mass of the body. + is_global: Whether to apply the external wrench in the global frame. Defaults to False. If set to False, + the external wrench is applied in the link frame of the articulations' bodies. """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. env_mask_ = None body_mask_ = None if isinstance(forces, torch.Tensor) or isinstance(torques, torch.Tensor): if forces is not None: - forces, env_mask_, body_mask_ = self._torch_to_warp_dual_index( - forces, self.num_instances, self.num_bodies, env_ids, body_ids, env_mask, body_mask, dtype=wp.vec3f + forces = make_complete_data_from_torch_dual_index( + forces, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.vec3f, device=self.device ) if torques is not None: - torques, env_mask_, body_mask_ = self._torch_to_warp_dual_index( - torques, self.num_instances, self.num_bodies, env_ids, body_ids, env_mask, body_mask, dtype=wp.vec3f + torques = make_complete_data_from_torch_dual_index( + torques, self.num_instances, self.num_bodies, env_ids, body_ids, dtype=wp.vec3f, device=self.device ) + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + body_mask = make_masks_from_torch_ids(self.num_bodies, body_ids, body_mask, device=self.device) # solve for None masks if env_mask_ is None: env_mask_ = self._data.ALL_ENV_MASK @@ -1497,14 +1396,11 @@ def set_joint_position_target( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(target, torch.Tensor): - target, env_mask, joint_mask = self._torch_to_warp_dual_index( - target, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + target = make_complete_data_from_torch_dual_index( + target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32 ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) # set into the actuator target buffer wp.launch( update_array2D_with_array2D_masked, @@ -1539,14 +1435,11 @@ def set_joint_velocity_target( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(target, torch.Tensor): - target, env_mask, joint_mask = self._torch_to_warp_dual_index( - target, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + target = make_complete_data_from_torch_dual_index( + target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32, device=self.device ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) # set into the actuator target buffer self._update_batched_array_with_batched_array_masked( target, self._data.actuator_velocity_target, env_mask, joint_mask, (self.num_instances, self.num_joints) @@ -1574,14 +1467,11 @@ def set_joint_effort_target( """ # Resolve indices into mask, convert from partial data to complete data, handles the conversion to warp. if isinstance(target, torch.Tensor): - target, env_mask, joint_mask = self._torch_to_warp_dual_index( - target, self.num_instances, self.num_joints, env_ids, joint_ids, env_mask, joint_mask, dtype=wp.float32 + target = make_complete_data_from_torch_dual_index( + target, self.num_instances, self.num_joints, env_ids, joint_ids, dtype=wp.float32, device=self.device ) - # solve for None masks - if env_mask is None: - env_mask = self._data.ALL_ENV_MASK - if joint_mask is None: - joint_mask = self._data.ALL_JOINT_MASK + env_mask = make_masks_from_torch_ids(self.num_instances, env_ids, env_mask, device=self.device) + joint_mask = make_masks_from_torch_ids(self.num_joints, joint_ids, joint_mask, device=self.device) # set into the actuator effort target buffer self._update_batched_array_with_batched_array_masked( target, self._data.actuator_effort_target, env_mask, joint_mask, (self.num_instances, self.num_joints) @@ -1934,21 +1824,15 @@ def _initialize_impl(self): self.update(0.0) # log joint information self._log_articulation_info() - - # Offsets the spawned pose by the default root pose prior to initializing the solver. This ensures that the - # solver is initialized at the correct pose, avoiding potential miscalculations in the maximum number of - # constraints or contact required to run the simulation. - # TODO: Do this is warp directly? - generated_pose = wp.to_torch(self._data._default_root_pose).clone() - generated_pose[:, :2] += wp.to_torch(self._root_view.get_root_transforms(NewtonManager.get_model()))[:, :2] - self._root_view.set_root_transforms(NewtonManager.get_state_0(), generated_pose) - self._root_view.set_root_transforms(NewtonManager.get_model(), generated_pose) + # Let the articulation data know that it is fully instantiated and ready to use. + self._data.is_primed = True def _create_buffers(self, *args, **kwargs): self._ALL_INDICES = torch.arange(self.num_instances, dtype=torch.long, device=self.device) wp.launch( update_soft_joint_pos_limits, dim=(self.num_instances, self.num_joints), + device=self.device, inputs=[ self._data.joint_pos_limits_lower, self._data.joint_pos_limits_upper, @@ -1986,6 +1870,7 @@ def _process_cfg(self): wp.launch( update_array2D_with_array1D_indexed, dim=(self.num_instances, len(indices_list)), + device=self.device, inputs=[ wp.array(values_list, dtype=wp.float32, device=self.device), self._data.default_joint_pos, @@ -2000,6 +1885,7 @@ def _process_cfg(self): wp.launch( update_array2D_with_array1D_indexed, dim=(self.num_instances, len(indices_list)), + device=self.device, inputs=[ wp.array(values_list, dtype=wp.float32, device=self.device), self._data.default_joint_vel, @@ -2008,6 +1894,19 @@ def _process_cfg(self): ], ) + # Offsets the spawned pose by the default root pose prior to initializing the solver. This ensures that the + # solver is initialized at the correct pose, avoiding potential miscalculations in the maximum number of + # constraints or contact required to run the simulation. + # TODO: Do this is warp directly? + generated_pose = wp.to_torch(self._data._default_root_pose).clone() + generated_pose[:, :2] += wp.to_torch(self._root_view.get_root_transforms(NewtonManager.get_model()))[:, :2] + self._root_view.set_root_transforms( + NewtonManager.get_state_0(), wp.from_torch(generated_pose, dtype=wp.transformf) + ) + self._root_view.set_root_transforms( + NewtonManager.get_model(), wp.from_torch(generated_pose, dtype=wp.transformf) + ) + """ Internal simulation callbacks. """ @@ -2038,7 +1937,7 @@ def _process_actuators_cfg(self): # type annotation for type checkers actuator_cfg: ActuatorBaseCfg # create actuator group - joint_mask, joint_names, joint_indices = self._find_joints(actuator_cfg.joint_names_expr) + joint_mask, joint_names, joint_indices = self.find_joints(actuator_cfg.joint_names_expr) # check if any joints are found if len(joint_names) == 0: raise ValueError( @@ -2165,10 +2064,11 @@ def _validate_cfg(self): if len(violated_indices) > 0: # prepare message for violated joints msg = "The following joints have default positions out of the limits: \n" + default_joint_pos = wp.to_torch(self._data._default_joint_pos) for idx in violated_indices: joint_name = self.data.joint_names[idx] joint_limit = joint_pos_limits[idx] - joint_pos = self.data.default_joint_pos[0, idx] + joint_pos = default_joint_pos[0, idx] # add to message msg += f"\t- '{joint_name}': {joint_pos:.3f} not in [{joint_limit[0]:.3f}, {joint_limit[1]:.3f}]\n" raise ValueError(msg) @@ -2255,6 +2155,7 @@ def _update_array_with_value( source, target, ], + device=self.device, ) def _update_array_with_value_masked( @@ -2280,6 +2181,7 @@ def _update_array_with_value_masked( target, mask, ], + device=self.device, ) def _update_array_with_array_masked(self, source: wp.array, target: wp.array, mask: wp.array, dim: int): @@ -2298,6 +2200,7 @@ def _update_array_with_array_masked(self, source: wp.array, target: wp.array, ma target, mask, ], + device=self.device, ) def _update_batched_array_with_batched_array_masked( @@ -2321,6 +2224,7 @@ def _update_batched_array_with_batched_array_masked( mask_1, mask_2, ], + device=self.device, ) def _update_batched_array_with_value_masked( @@ -2349,73 +2253,13 @@ def _update_batched_array_with_value_masked( mask_1, mask_2, ], + device=self.device, ) - def _find_bodies( - self, name_keys: str | Sequence[str], preserve_order: bool = False - ) -> tuple[wp.array, list[str], list[int]]: - """Find bodies in the articulation based on the name keys. - - Please check the :meth:`isaaclab.utils.string_utils.resolve_matching_names` function for more - information on the name matching. - - Args: - name_keys: A regular expression or a list of regular expressions to match the body names. - preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False. - - Returns: - A tuple of lists containing the body mask, names, and indices. - """ - indices, names = string_utils.resolve_matching_names(name_keys, self.body_names, preserve_order) - self._data.BODY_MASK.fill_(False) - mask = wp.clone(self._data.BODY_MASK) - wp.launch( - generate_mask_from_ids, - dim=(len(indices),), - inputs=[ - mask, - wp.array(indices, dtype=wp.int32, device=self._device), - ], - ) - return mask, names, indices - - def _find_joints( - self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False - ) -> tuple[wp.array, list[str], list[int]]: - """Find joints in the articulation based on the name keys. - - Please see the :func:`isaaclab.utils.string.resolve_matching_names` function for more information - on the name matching. - - Args: - name_keys: A regular expression or a list of regular expressions to match the joint names. - joint_subset: A subset of joints to search for. Defaults to None, which means all joints - in the articulation are searched. - preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False. - - Returns: - A tuple of lists containing the joint mask, names, and indices. - """ - if joint_subset is None: - joint_subset = self.joint_names - # find joints - indices, names = string_utils.resolve_matching_names(name_keys, joint_subset, preserve_order) - self._data.JOINT_MASK.fill_(False) - mask = wp.clone(self._data.JOINT_MASK) - wp.launch( - generate_mask_from_ids, - dim=(len(indices),), - inputs=[ - mask, - wp.array(indices, dtype=wp.int32, device=self._device), - ], - ) - return mask, names, indices - def _write_joint_position_limit_to_sim( self, - upper_limits: wp.array | float, lower_limits: wp.array | float, + upper_limits: wp.array | float, joint_mask: wp.array, env_mask: wp.array, ) -> None: @@ -2430,60 +2274,52 @@ def _write_joint_position_limit_to_sim( # note: This function isn't setting the values for actuator models. (#128) # resolve indices - if isinstance(upper_limits, float) and isinstance(lower_limits, float): - # update default joint pos to stay within the new limits + if isinstance(lower_limits, float): self._update_batched_array_with_value_masked( - wp.vec2f(upper_limits, lower_limits), - self._data.default_joint_pos, + lower_limits, + self._data.joint_pos_limits_lower, env_mask, joint_mask, (self.num_instances, self.num_joints), ) - # set into simulation - wp.launch( - update_joint_limits_value_vec2f, - dim=(self.num_instances, self.num_joints), - inputs=[ - wp.vec2f(upper_limits, lower_limits), - self.cfg.soft_joint_pos_limit_factor, - self._data.joint_pos_limits_lower, - self._data.joint_pos_limits_upper, - self._data.soft_joint_pos_limits, - env_mask, - joint_mask, - ], - ) - elif isinstance(upper_limits, wp.array) and isinstance(lower_limits, wp.array): - # update default joint pos to stay within the new limits + else: self._update_batched_array_with_batched_array_masked( lower_limits, + self._data.joint_pos_limits_lower, + env_mask, + joint_mask, + (self.num_instances, self.num_joints), + ) + if isinstance(upper_limits, float): + self._update_batched_array_with_value_masked( upper_limits, - self._data.default_joint_pos, + self._data.joint_pos_limits_upper, env_mask, joint_mask, (self.num_instances, self.num_joints), ) - # set into simulation - wp.launch( - update_joint_limits, - dim=(self.num_instances, self.num_joints), - inputs=[ - lower_limits, - upper_limits, - self.cfg.soft_joint_pos_limit_factor, - self._data.joint_pos_limits_lower, - self._data.joint_pos_limits_upper, - self._data.soft_joint_pos_limits, - env_mask, - joint_mask, - ], + else: + self._update_batched_array_with_batched_array_masked( + upper_limits, + self._data.joint_pos_limits_upper, + env_mask, + joint_mask, + (self.num_instances, self.num_joints), ) - @warn_overhead_cost( - "N/A", - "Launches a kernel to split the state into pose and velocity. Consider setting the pose and velocity" - " independently instead.", - ) + # Update soft joint limits + wp.launch( + update_soft_joint_pos_limits, + dim=(self.num_instances, self.num_joints), + inputs=[ + self._data.joint_pos_limits_lower, + self._data.joint_pos_limits_upper, + self._data.soft_joint_pos_limits, + self.cfg.soft_joint_pos_limit_factor, + ], + device=self.device, + ) + def _split_state( self, state: wp.array, @@ -2500,19 +2336,13 @@ def _split_state( tmp_velocity = wp.zeros((self.num_instances,), dtype=wp.spatial_vectorf, device=self._device) wp.launch( - split_state_to_pose, + split_state_to_pose_and_velocity, dim=self.num_instances, inputs=[ state, tmp_pose, - ], - ) - wp.launch( - split_state_to_velocity, - dim=self.num_instances, - inputs=[ - state, tmp_velocity, ], + device=self.device, ) return tmp_pose, tmp_velocity diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py index fd89f1c8c32..b533e11a31c 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py @@ -87,6 +87,7 @@ def __init__(self, root_view, device: str): # Set initial time stamp self._sim_timestamp = 0.0 + self._is_primed = False # obtain global simulation view gravity = wp.to_torch(NewtonManager.get_model().gravity)[0] gravity_dir = math_utils.normalize(gravity.unsqueeze(0)).squeeze(0) @@ -101,6 +102,27 @@ def __init__(self, root_view, device: str): self._create_simulation_bindings() self._create_buffers() + @property + def is_primed(self) -> bool: + """Whether the articulation data is fully instantiated and ready to use.""" + return self._is_primed + + @is_primed.setter + def is_primed(self, value: bool): + """Set whether the articulation data is fully instantiated and ready to use. + + ..note:: Once this quantity is set to True, it cannot be changed. + + Args: + value: Whether the articulation data is fully instantiated and ready to use. + + Raises: + RuntimeError: If the articulation data is already fully instantiated and ready to use. + """ + if self._is_primed: + raise RuntimeError("Cannot set is_primed after instantiation.") + self._is_primed = value + ## # Names. ## @@ -131,6 +153,22 @@ def default_root_pose(self) -> wp.array(dtype=wp.transformf): """ return self._default_root_pose + @default_root_pose.setter + def default_root_pose(self, value: wp.array(dtype=wp.transformf)): + """Set default root pose ``[pos, quat]`` in the local environment frame. + + ..note:: Once this quantity is set to True, it cannot be changed. + + Args: + value: Default root pose ``[pos, quat]`` in the local environment frame. + + Raises: + RuntimeError: If the articulation data is already fully instantiated and ready to use. + """ + if self._is_primed: + raise RuntimeError("Cannot set default root pose after instantiation.") + self._default_root_pose = value + @property def default_root_vel(self) -> wp.array(dtype=wp.spatial_vectorf): """Default root velocity ``[lin_vel, ang_vel]`` in the local environment frame. Shape is (num_instances, 6). @@ -141,6 +179,22 @@ def default_root_vel(self) -> wp.array(dtype=wp.spatial_vectorf): """ return self._default_root_vel + @default_root_vel.setter + def default_root_vel(self, value: wp.array(dtype=wp.spatial_vectorf)): + """Set default root velocity ``[lin_vel, ang_vel]`` in the local environment frame. + + ..note:: Once this quantity is set to True, it cannot be changed. + + Args: + value: Default root velocity ``[lin_vel, ang_vel]`` in the local environment frame. + + Raises: + RuntimeError: If the articulation data is already fully instantiated and ready to use. + """ + if self._is_primed: + raise RuntimeError("Cannot set default root velocity after instantiation.") + self._default_root_vel = value + @property def default_joint_pos(self) -> wp.array(dtype=wp.float32): """Default joint positions of all joints. Shape is (num_instances, num_joints). @@ -149,6 +203,22 @@ def default_joint_pos(self) -> wp.array(dtype=wp.float32): """ return self._default_joint_pos + @default_joint_pos.setter + def default_joint_pos(self, value: wp.array(dtype=wp.float32)): + """Set default joint positions of all joints. + + ..note:: Once this quantity is set to True, it cannot be changed. + + Args: + value: Default joint positions of all joints. + + Raises: + RuntimeError: If the articulation data is already fully instantiated and ready to use. + """ + if self._is_primed: + raise RuntimeError("Cannot set default joint positions after instantiation.") + self._default_joint_pos = value + @property def default_joint_vel(self) -> wp.array(dtype=wp.float32): """Default joint velocities of all joints. Shape is (num_instances, num_joints). @@ -157,9 +227,25 @@ def default_joint_vel(self) -> wp.array(dtype=wp.float32): """ return self._default_joint_vel - ### + @default_joint_vel.setter + def default_joint_vel(self, value: wp.array(dtype=wp.float32)): + """Set default joint velocities of all joints. + + ..note:: Once this quantity is set to True, it cannot be changed. + + Args: + value: Default joint velocities of all joints. + + Raises: + RuntimeError: If the articulation data is already fully instantiated and ready to use. + """ + if self._is_primed: + raise RuntimeError("Cannot set default joint velocities after instantiation.") + self._default_joint_vel = value + + ## # Joint commands. -- Set into the simulation - ### + ## @property def joint_pos_target(self) -> wp.array(dtype=wp.float32): @@ -284,22 +370,36 @@ def joint_pos_limits_upper(self) -> wp.array(dtype=wp.float32): return self._sim_bind_joint_pos_limits_upper @property + @warn_overhead_cost( + "joint_pos_limits", + "Launches a kernel to compute the joint position limits from the lower and upper limits. Consider using the" + " joint_pos_limits_lower and joint_pos_limits_upper properties instead.", + ) def joint_pos_limits(self) -> wp.array(dtype=wp.vec2f): """Joint position limits provided to the simulation. Shape is (num_instances, num_joints, 2). The limits are in the order :math:`[lower, upper]`. + + .. caution:: This property is computed on-the-fly, and while it returns a pointer, writing to that pointer + will not affect change the joint position limits. To change the joint position limits, use the + :attr:`joint_pos_limits_lower` and :attr:`joint_pos_limits_upper` properties. """ - out = wp.zeros((self._root_view.count, self._root_view.joint_dof_count), dtype=wp.vec2f, device=self.device) + if self._joint_pos_limits is None: + self._joint_pos_limits = wp.zeros( + (self._root_view.count, self._root_view.joint_dof_count), dtype=wp.vec2f, device=self.device + ) + wp.launch( make_joint_pos_limits_from_lower_and_upper_limits, dim=(self._root_view.count, self._root_view.joint_dof_count), inputs=[ self._sim_bind_joint_pos_limits_lower, self._sim_bind_joint_pos_limits_upper, - out, + self._joint_pos_limits, ], + device=self.device, ) - return out + return self._joint_pos_limits @property def joint_vel_limits(self) -> wp.array(dtype=wp.float32): @@ -507,7 +607,8 @@ def root_state_w(self) -> wp.array(dtype=vec13f): The pose is of the articulation root's actor frame relative to the world. The velocity is of the articulation root's center of mass frame. """ - state = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) + if self._root_state_w is None: + self._root_state_w = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) wp.launch( combine_pose_and_velocity_to_state, dim=(self._root_view.count,), @@ -515,10 +616,10 @@ def root_state_w(self) -> wp.array(dtype=vec13f): inputs=[ self._sim_bind_root_link_pose_w, self._sim_bind_root_com_vel_w, - state, + self._root_state_w, ], ) - return state + return self._root_state_w @property @warn_overhead_cost( @@ -535,7 +636,9 @@ def root_link_state_w(self) -> wp.array(dtype=vec13f): The pose is of the articulation root's actor frame relative to the world. The velocity is of the articulation root's actor frame. """ - state = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) + if self._root_link_state_w is None: + self._root_link_state_w = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) + wp.launch( combine_pose_and_velocity_to_state, dim=(self._root_view.count,), @@ -543,10 +646,10 @@ def root_link_state_w(self) -> wp.array(dtype=vec13f): inputs=[ self._sim_bind_root_link_pose_w, self.root_link_vel_w, - state, + self._root_link_state_w, ], ) - return state + return self._root_link_state_w @property @warn_overhead_cost( @@ -563,7 +666,9 @@ def root_com_state_w(self) -> wp.array(dtype=vec13f): The pose is of the articulation root's center of mass frame relative to the world. The velocity is of the articulation root's center of mass frame. """ - state = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) + if self._root_com_state_w is None: + self._root_com_state_w = wp.zeros((self._root_view.count), dtype=vec13f, device=self.device) + wp.launch( combine_pose_and_velocity_to_state, dim=(self._root_view.count,), @@ -571,10 +676,10 @@ def root_com_state_w(self) -> wp.array(dtype=vec13f): inputs=[ self.root_com_pose_w, self._sim_bind_root_com_vel_w, - state, + self._root_com_state_w, ], ) - return state + return self._root_com_state_w ## # Body state properties. @@ -756,6 +861,7 @@ def body_com_acc_w(self) -> wp.array(dtype=wp.spatial_vectorf): wp.launch( derive_body_acceleration_from_velocity_batched, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_com_vel_w, self._previous_body_com_vel, @@ -765,8 +871,8 @@ def body_com_acc_w(self) -> wp.array(dtype=wp.spatial_vectorf): ) # set the buffer data and timestamp self._body_com_acc_w.timestamp = self._sim_timestamp - # update the previous body velocity for next finite differencing - wp.copy(self._previous_body_com_vel, self._sim_bind_body_com_vel_w) + # update the previous velocity + self._previous_body_com_vel.assign(self._sim_bind_body_com_vel_w) return self._body_com_acc_w.data @property @@ -781,6 +887,7 @@ def body_com_pose_b(self) -> wp.array(dtype=wp.transformf): wp.launch( generate_pose_from_position_with_unit_quaternion_batched, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_com_pos_b, out, @@ -789,6 +896,7 @@ def body_com_pose_b(self) -> wp.array(dtype=wp.transformf): return out # TODO: Make sure this is implemented when the feature is available in Newton. + # TODO: Waiting on https://github.com/newton-physics/newton/pull/1161 ETA: early JAN 2026. @property def body_incoming_joint_wrench_b(self) -> wp.array(dtype=wp.spatial_vectorf): """Joint reaction wrench applied from body parent to child body in parent body frame. @@ -820,6 +928,7 @@ def joint_acc(self) -> wp.array(dtype=wp.float32): wp.launch( derive_joint_acceleration_from_velocity, dim=(self._root_view.count, self._root_view.joint_dof_count), + device=self.device, inputs=[ self._sim_bind_joint_vel, self._previous_joint_vel, @@ -828,8 +937,8 @@ def joint_acc(self) -> wp.array(dtype=wp.float32): ], ) self._joint_acc.timestamp = self._sim_timestamp - # update the previous joint velocity for next finite differencing - wp.copy(self._previous_joint_vel, self._sim_bind_joint_vel) + # update the previous joint velocity + self._previous_joint_vel.assign(self._sim_bind_joint_vel) return self._joint_acc.data ## @@ -843,6 +952,7 @@ def projected_gravity_b(self) -> wp.array(dtype=wp.vec3f): wp.launch( project_vec_from_pose_single, dim=self._root_view.count, + device=self.device, inputs=[ self.GRAVITY_VEC_W, self._sim_bind_root_link_pose_w, @@ -865,6 +975,7 @@ def heading_w(self) -> wp.array(dtype=wp.float32): wp.launch( compute_heading, dim=self._root_view.count, + device=self.device, inputs=[ self.FORWARD_VEC_B, self._sim_bind_root_link_pose_w, @@ -885,6 +996,7 @@ def root_link_vel_b(self) -> wp.array(dtype=wp.spatial_vectorf): wp.launch( project_velocities_to_frame, dim=self._root_view.count, + device=self.device, inputs=[ self.root_link_vel_w, self._sim_bind_root_link_pose_w, @@ -905,6 +1017,7 @@ def root_com_vel_b(self) -> wp.array(dtype=wp.spatial_vectorf): wp.launch( project_velocities_to_frame, dim=self._root_view.count, + device=self.device, inputs=[ self._sim_bind_root_com_vel_w, self._sim_bind_root_link_pose_w, @@ -935,7 +1048,11 @@ def root_link_lin_vel_b(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_link_lin_vel_b = wp.array( - ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -946,6 +1063,7 @@ def root_link_lin_vel_b(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_linear_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_link_lin_vel_b, @@ -973,7 +1091,11 @@ def root_link_ang_vel_b(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_link_ang_vel_b = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -984,6 +1106,7 @@ def root_link_ang_vel_b(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_angular_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_link_ang_vel_b, @@ -1011,7 +1134,11 @@ def root_com_lin_vel_b(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_com_lin_vel_b = wp.array( - ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -1022,6 +1149,7 @@ def root_com_lin_vel_b(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_linear_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_com_lin_vel_b, @@ -1049,7 +1177,11 @@ def root_com_ang_vel_b(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_com_ang_vel_b = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -1060,6 +1192,7 @@ def root_com_ang_vel_b(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_angular_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_com_ang_vel_b, @@ -1092,6 +1225,7 @@ def root_link_pos_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_root_link_pose_w.shape, strides=self._sim_bind_root_link_pose_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1102,6 +1236,7 @@ def root_link_pos_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_transform_array_to_position_array, dim=self._root_view.count, + device=self.device, inputs=[ self._sim_bind_root_link_pose_w, self._root_link_pos_w, @@ -1131,6 +1266,7 @@ def root_link_quat_w(self) -> wp.array(dtype=wp.quatf): dtype=wp.quatf, shape=self._sim_bind_root_link_pose_w.shape, strides=self._sim_bind_root_link_pose_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1141,6 +1277,7 @@ def root_link_quat_w(self) -> wp.array(dtype=wp.quatf): wp.launch( split_transform_array_to_quaternion_array, dim=self._root_view.count, + device=self.device, inputs=[ self._sim_bind_root_link_pose_w, self._root_link_quat_w, @@ -1167,7 +1304,11 @@ def root_link_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_link_lin_vel_w = wp.array( - ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -1178,6 +1319,7 @@ def root_link_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_linear_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_link_lin_vel_w, @@ -1204,7 +1346,11 @@ def root_link_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._root_link_ang_vel_w = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -1215,6 +1361,7 @@ def root_link_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_angular_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ data, self._root_link_ang_vel_w, @@ -1240,7 +1387,9 @@ def root_com_pos_w(self) -> wp.array(dtype=wp.vec3f): if self._root_com_pos_w is None: if data.is_contiguous: # Create a memory view of the data - self._root_com_pos_w = wp.array(ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides) + self._root_com_pos_w = wp.array( + ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides, device=self.device + ) else: # Create a new buffer self._root_com_pos_w = wp.zeros((self._root_view.count,), dtype=wp.vec3f, device=self.device) @@ -1277,7 +1426,7 @@ def root_com_quat_w(self) -> wp.array(dtype=wp.quatf): if data.is_contiguous: # Create a memory view of the data self._root_com_quat_w = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.quatf, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, dtype=wp.quatf, shape=data.shape, strides=data.strides, device=self.device ) else: # Create a new buffer @@ -1316,6 +1465,7 @@ def root_com_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_root_com_vel_w.shape, strides=self._sim_bind_root_com_vel_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1326,6 +1476,7 @@ def root_com_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_linear_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ self._sim_bind_root_com_vel_w, self._root_com_lin_vel_w, @@ -1354,6 +1505,7 @@ def root_com_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_root_com_vel_w.shape, strides=self._sim_bind_root_com_vel_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1364,6 +1516,7 @@ def root_com_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_array_to_angular_velocity_array, dim=self._root_view.count, + device=self.device, inputs=[ self._sim_bind_root_com_vel_w, self._root_com_ang_vel_w, @@ -1392,6 +1545,7 @@ def body_link_pos_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_body_link_pose_w.shape, strides=self._sim_bind_body_link_pose_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1404,6 +1558,7 @@ def body_link_pos_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_transform_batched_array_to_position_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_link_pose_w, self._body_link_pos_w, @@ -1433,6 +1588,7 @@ def body_link_quat_w(self) -> wp.array(dtype=wp.quatf): dtype=wp.quatf, shape=self._sim_bind_body_link_pose_w.shape, strides=self._sim_bind_body_link_pose_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1445,6 +1601,7 @@ def body_link_quat_w(self) -> wp.array(dtype=wp.quatf): wp.launch( split_transform_batched_array_to_quaternion_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_link_pose_w, self._body_link_quat_w, @@ -1471,7 +1628,7 @@ def body_link_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._body_link_lin_vel_w = wp.array( - ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides, device=self.device ) else: # Create a new buffer @@ -1484,6 +1641,7 @@ def body_link_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_linear_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_link_lin_vel_w, @@ -1510,7 +1668,7 @@ def body_link_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._body_link_ang_vel_w = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides, device=self.device ) else: # Create a new buffer @@ -1523,6 +1681,7 @@ def body_link_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_angular_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_link_ang_vel_w, @@ -1548,7 +1707,9 @@ def body_com_pos_w(self) -> wp.array(dtype=wp.vec3f): if self._body_com_pos_w is None: if data.is_contiguous: # Create a memory view of the data - self._body_com_pos_w = wp.array(ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides) + self._body_com_pos_w = wp.array( + ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides, device=self.device + ) else: # Create a new buffer self._body_com_pos_w = wp.zeros( @@ -1560,6 +1721,7 @@ def body_com_pos_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_transform_batched_array_to_position_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_com_pos_w, @@ -1587,7 +1749,7 @@ def body_com_quat_w(self) -> wp.array(dtype=wp.quatf): if data.is_contiguous: # Create a memory view of the data self._body_com_quat_w = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.quatf, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, dtype=wp.quatf, shape=data.shape, strides=data.strides, device=self.device ) else: # Create a new buffer @@ -1600,6 +1762,7 @@ def body_com_quat_w(self) -> wp.array(dtype=wp.quatf): wp.launch( split_transform_batched_array_to_quaternion_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_com_quat_w, @@ -1628,6 +1791,7 @@ def body_com_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_body_com_vel_w.shape, strides=self._sim_bind_body_com_vel_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1640,6 +1804,7 @@ def body_com_lin_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_linear_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_com_vel_w, self._body_com_lin_vel_w, @@ -1668,6 +1833,7 @@ def body_com_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): dtype=wp.vec3f, shape=self._sim_bind_body_com_vel_w.shape, strides=self._sim_bind_body_com_vel_w.strides, + device=self.device, ) else: # Create a new buffer @@ -1680,6 +1846,7 @@ def body_com_ang_vel_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_angular_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ self._sim_bind_body_com_vel_w, self._body_com_ang_vel_w, @@ -1706,7 +1873,7 @@ def body_com_lin_acc_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._body_com_lin_acc_w = wp.array( - ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr, dtype=wp.vec3f, shape=data.shape, strides=data.strides, device=self.device ) else: # Create a new buffer @@ -1719,6 +1886,7 @@ def body_com_lin_acc_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_linear_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_com_lin_acc_w, @@ -1745,7 +1913,11 @@ def body_com_ang_acc_w(self) -> wp.array(dtype=wp.vec3f): if data.is_contiguous: # Create a memory view of the data self._body_com_ang_acc_w = wp.array( - ptr=data.ptr + 3 * 4, dtype=wp.vec3f, shape=data.shape, strides=data.strides + ptr=data.ptr + 3 * 4, + dtype=wp.vec3f, + shape=data.shape, + strides=data.strides, + device=self.device, ) else: # Create a new buffer @@ -1758,6 +1930,7 @@ def body_com_ang_acc_w(self) -> wp.array(dtype=wp.vec3f): wp.launch( split_spatial_vectory_batched_array_to_angular_velocity_batched_array, dim=(self._root_view.count, self._root_view.link_count), + device=self.device, inputs=[ data, self._body_com_ang_acc_w, @@ -1784,18 +1957,15 @@ def body_com_quat_b(self) -> wp.array(dtype=wp.quatf): """Orientation (x, y, z, w) of the principle axis of inertia of all of the bodies in their respective link frames. Shape is (num_instances, num_bodies, 4). - This quantity is the orientation of the principles axes of inertia relative to its body's link frame. + This quantity is the orientation of the principles axes of inertia relative to its body's link frame. In Newton + this quantity is always a unit quaternion. """ - out = wp.zeros((self._root_view.count, self._root_view.link_count), dtype=wp.quatf, device=self.device) - wp.launch( - split_transform_batched_array_to_quaternion_batched_array, - dim=(self._root_view.count, self._root_view.link_count), - inputs=[ - self.body_com_pose_b, - out, - ], - ) - return out + if self._body_com_quat_b is None: + self._body_com_quat_b = wp.zeros( + (self._root_view.count, self._root_view.link_count), dtype=wp.quatf, device=self.device + ) + self._body_com_quat_b.fill_(wp.quat_identity(wp.float32)) + return self._body_com_quat_b ## # Backward compatibility. -- Deprecated properties. @@ -2094,22 +2264,30 @@ def _create_buffers(self) -> None: # Initialize the lazy buffers. # -- link frame w.r.t. world frame - self._root_link_vel_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf) - self._root_link_vel_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf) - self._projected_gravity_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.vec3f) - self._heading_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.float32) - self._body_link_vel_w = TimestampedWarpBuffer(shape=(n_view, n_link), dtype=wp.spatial_vectorf) + self._root_link_vel_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf, device=self.device) + self._root_link_vel_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf, device=self.device) + self._projected_gravity_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.vec3f, device=self.device) + self._heading_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.float32, device=self.device) + self._body_link_vel_w = TimestampedWarpBuffer( + shape=(n_view, n_link), dtype=wp.spatial_vectorf, device=self.device + ) # -- com frame w.r.t. world frame - self._root_com_pose_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.transformf) - self._root_com_vel_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf) - self._root_com_acc_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf) - self._body_com_pose_w = TimestampedWarpBuffer(shape=(n_view, n_link), dtype=wp.transformf) - self._body_com_acc_w = TimestampedWarpBuffer(shape=(n_view, n_link), dtype=wp.spatial_vectorf) + self._root_com_pose_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.transformf, device=self.device) + self._root_com_vel_b = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf, device=self.device) + self._root_com_acc_w = TimestampedWarpBuffer(shape=(n_view,), dtype=wp.spatial_vectorf, device=self.device) + self._body_com_pose_w = TimestampedWarpBuffer(shape=(n_view, n_link), dtype=wp.transformf, device=self.device) + self._body_com_acc_w = TimestampedWarpBuffer( + shape=(n_view, n_link), dtype=wp.spatial_vectorf, device=self.device + ) # -- joint state - self._joint_acc = TimestampedWarpBuffer(shape=(n_view, n_dof), dtype=wp.float32) - # self._body_incoming_joint_wrench_b = TimestampedWarpBuffer(shape=(n_view, n_dof), dtype=wp.spatial_vectorf) - + self._joint_acc = TimestampedWarpBuffer(shape=(n_view, n_dof), dtype=wp.float32, device=self.device) + # self._body_incoming_joint_wrench_b = TimestampedWarpBuffer(shape=(n_view, n_dof), dtype=wp.spatial_vectorf, device=self.device) # Empty memory pre-allocations + self._joint_pos_limits = None + self._root_state_w = None + self._root_link_state_w = None + self._root_com_state_w = None + self._body_com_quat_b = None self._root_link_lin_vel_b = None self._root_link_ang_vel_b = None self._root_com_lin_vel_b = None diff --git a/source/isaaclab_newton/isaaclab_newton/assets/utils/shared.py b/source/isaaclab_newton/isaaclab_newton/assets/utils/shared.py new file mode 100644 index 00000000000..9ce43530c7b --- /dev/null +++ b/source/isaaclab_newton/isaaclab_newton/assets/utils/shared.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +from collections.abc import Sequence + +import warp as wp + +import isaaclab.utils.string as string_utils + + +def find_bodies( + body_names: list[str], + name_keys: str | Sequence[str], + preserve_order: bool = False, + device: str = "cuda:0", +) -> tuple[wp.array, list[str], list[int]]: + """Find bodies in the articulation based on the name keys. + + Please check the :meth:`isaaclab.utils.string_utils.resolve_matching_names` function for more + information on the name matching. + + Args: + body_names: The names of all the bodies in the articulation / assets. + name_keys: A regular expression or a list of regular expressions to match the body names. + preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False. + device: The device to use for the output mask. Defaults to "cuda:0". + Returns: + A tuple of lists containing the body mask, names, and indices. + """ + indices, names = string_utils.resolve_matching_names(name_keys, body_names, preserve_order) + mask = np.zeros(len(body_names), dtype=bool) + mask[indices] = True + mask = wp.array(mask, dtype=wp.bool, device=device) + return mask, names, indices + + +def find_joints( + joint_names: list[str], + name_keys: str | Sequence[str], + joint_subset: list[str] | None = None, + preserve_order: bool = False, + device: str = "cuda:0", +) -> tuple[wp.array, list[str], list[int]]: + """Find joints in the articulation based on the name keys. + + Please see the :func:`isaaclab.utils.string.resolve_matching_names` function for more information + on the name matching. + + Args: + joint_names: The names of all the joints in the articulation / assets. + name_keys: A regular expression or a list of regular expressions to match the joint names. + joint_subset: A subset of joints to search for. Defaults to None, which means all joints + in the articulation are searched. + preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False. + device: The device to use for the output mask. Defaults to "cuda:0". + Returns: + A tuple of lists containing the joint mask, names, and indices. + """ + if joint_subset is None: + joint_subset = joint_names + # find joints + indices, names = string_utils.resolve_matching_names(name_keys, joint_subset, preserve_order) + mask = np.zeros(len(joint_names), dtype=bool) + mask[indices] = True + mask = wp.array(mask, dtype=wp.bool, device=device) + return mask, names, indices diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/joint_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/joint_kernels.py index be983ebaca8..1a3e2bb66e8 100644 --- a/source/isaaclab_newton/isaaclab_newton/kernels/joint_kernels.py +++ b/source/isaaclab_newton/isaaclab_newton/kernels/joint_kernels.py @@ -10,133 +10,6 @@ """ -@wp.kernel -def update_joint_array( - new_data: wp.array2d(dtype=wp.float32), - joint_data: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """ - Update the joint data for the given environment and joint indices from the newton data. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. The :arg:`new_data` shape - must match the :arg:`joint_data` shape. - - Args: - new_data: The new data to update the joint data with. Shape is (num_instances, num_joints). - joint_data: The joint data to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint data for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint data for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_data[env_index, joint_index] = new_data[env_index, joint_index] - - -@wp.kernel -def update_joint_array_int( - new_data: wp.array2d(dtype=wp.int32), - joint_data: wp.array2d(dtype=wp.int32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """ - Update the joint data for the given environment and joint indices from the newton data. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. The :arg:`new_data` shape - must match the :arg:`joint_data` shape. - - Args: - new_data: The new data to update the joint data with. Shape is (num_instances, num_joints). - joint_data: The joint data to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint data for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint data for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_data[env_index, joint_index] = new_data[env_index, joint_index] - - -@wp.kernel -def update_joint_array_with_value_array( - value: wp.array(dtype=wp.float32), - joint_data: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint data for the given environment and joint indices with a value array. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. The :arg:`value` shape - must be (num_joints,). - - Args: - value: The value array to update the joint data with. Shape is (num_joints,). - joint_data: The joint data to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint data for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint data for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_data[env_index, joint_index] = value[joint_index] - - -@wp.kernel -def update_joint_array_with_value( - value: wp.float32, - joint_data: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint data for the given environment and joint indices with a value. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. The :arg:`joint_data` shape - must be (num_instances, num_joints). - - Args: - value: The value to update the joint data with. - joint_data: The joint data to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint data for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint data for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_data[env_index, joint_index] = value - - -@wp.kernel -def update_joint_array_with_value_int( - value: wp.int32, - joint_data: wp.array2d(dtype=wp.int32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint data for the given environment and joint indices with a value. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. The :arg:`joint_data` shape - must be (num_instances, num_joints). - - Args: - value: The value to update the joint data with. - joint_data: The joint data to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint data for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint data for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_data[env_index, joint_index] = value - - -""" -Kernels to update joint limits. -""" - - @wp.func def get_soft_joint_limits(lower_limit: float, upper_limit: float, soft_factor: float) -> wp.vec2f: """Get the soft joint limits for the given lower and upper limits and soft factor. @@ -156,191 +29,6 @@ def get_soft_joint_limits(lower_limit: float, upper_limit: float, soft_factor: f return wp.vec2f(lower_limit, upper_limit) -@wp.kernel -def update_joint_limits( - new_limits_lower: wp.array2d(dtype=wp.float32), - new_limits_upper: wp.array2d(dtype=wp.float32), - soft_factor: float, - lower_limits: wp.array2d(dtype=wp.float32), - upper_limits: wp.array2d(dtype=wp.float32), - soft_joint_limits: wp.array2d(dtype=wp.vec2f), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint limits for the given environment and joint indices. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. - - Args: - new_limits_lower: The new lower limits to update the joint limits with. Shape is (num_instances, num_joints). - new_limits_upper: The new upper limits to update the joint limits with. Shape is (num_instances, num_joints). - soft_factor: The soft factor to use for the soft joint limits. - lower_limits: The lower limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - upper_limits: The upper limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - soft_joint_limits: The soft joint limits to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint limits for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint limits for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - lower_limits[env_index, joint_index] = new_limits_lower[env_index, joint_index] - upper_limits[env_index, joint_index] = new_limits_upper[env_index, joint_index] - - soft_joint_limits[env_index, joint_index] = get_soft_joint_limits( - lower_limits[env_index, joint_index], upper_limits[env_index, joint_index], soft_factor - ) - - -@wp.kernel -def update_joint_limits_with_value( - new_limits: float, - soft_factor: float, - lower_limits: wp.array2d(dtype=wp.float32), - upper_limits: wp.array2d(dtype=wp.float32), - soft_joint_limits: wp.array2d(dtype=wp.vec2f), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint limits for the given environment and joint indices with a value. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. - - Args: - new_limits: The new limits to update the joint limits with. - soft_factor: The soft factor to use for the soft joint limits. - lower_limits: The lower limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - upper_limits: The upper limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - soft_joint_limits: The soft joint limits to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint limits for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint limits for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - lower_limits[env_index, joint_index] = new_limits - upper_limits[env_index, joint_index] = new_limits - - soft_joint_limits[env_index, joint_index] = get_soft_joint_limits( - lower_limits[env_index, joint_index], upper_limits[env_index, joint_index], soft_factor - ) - - -@wp.kernel -def update_joint_limits_value_vec2f( - new_limits: wp.vec2f, - soft_factor: float, - lower_limits: wp.array2d(dtype=wp.float32), - upper_limits: wp.array2d(dtype=wp.float32), - soft_joint_limits: wp.array2d(dtype=wp.vec2f), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint limits for the given environment and joint indices with a value. - - Args: - new_limits: The new limits to update the joint limits with. - soft_factor: The soft factor to use for the soft joint limits. - lower_limits: The lower limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - upper_limits: The upper limits to update the joint limits with. Shape is (num_instances, num_joints). (modified) - soft_joint_limits: The soft joint limits to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint limits for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint limits for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - lower_limits[env_index, joint_index] = new_limits[0] - upper_limits[env_index, joint_index] = new_limits[1] - - soft_joint_limits[env_index, joint_index] = get_soft_joint_limits( - lower_limits[env_index, joint_index], upper_limits[env_index, joint_index], soft_factor - ) - - -""" -Kernels to update joint position from joint limits. -""" - - -@wp.kernel -def update_joint_pos_with_limits( - joint_pos_limits_lower: wp.array2d(dtype=wp.float32), - joint_pos_limits_upper: wp.array2d(dtype=wp.float32), - joint_pos: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint position for the given environment and joint indices with the limits. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. - - Args: - joint_pos_limits_lower: The lower limits to update the joint position with. Shape is (num_instances, num_joints). - joint_pos_limits_upper: The upper limits to update the joint position with. Shape is (num_instances, num_joints). - joint_pos: The joint position to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint position for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint position for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_pos[env_index, joint_index] = wp.clamp( - joint_pos[env_index, joint_index], - joint_pos_limits_lower[env_index, joint_index], - joint_pos_limits_upper[env_index, joint_index], - ) - - -@wp.kernel -def update_joint_pos_with_limits_value( - joint_pos_limits: float, - joint_pos: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint position for the given environment and joint indices with the limits. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. - - Args: - joint_pos_limits: The joint position limits to update. - joint_pos: The joint position to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint position for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint position for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_pos[env_index, joint_index] = wp.clamp( - joint_pos[env_index, joint_index], joint_pos_limits, joint_pos_limits - ) - - -@wp.kernel -def update_joint_pos_with_limits_value_vec2f( - joint_pos_limits: wp.vec2f, - joint_pos: wp.array2d(dtype=wp.float32), - env_mask: wp.array(dtype=bool), - joint_mask: wp.array(dtype=bool), -): - """Update the joint position for the given environment and joint indices with the limits. - - .. note:: The :arg:`env_mask` length must be equal to the number of instances in the newton data. - The :arg:`joint_mask` length must be equal to the number of joints in the newton data. - - Args: - joint_pos_limits: The joint position limits to update. Shape is (2,) - joint_pos: The joint position to update. Shape is (num_instances, num_joints). (modified) - env_mask: The environment mask to update the joint position for. Shape is (num_instances,). - joint_mask: The joint mask to update the joint position for. Shape is (num_joints,). - """ - env_index, joint_index = wp.tid() - if env_mask[env_index] and joint_mask[joint_index]: - joint_pos[env_index, joint_index] = wp.clamp( - joint_pos[env_index, joint_index], joint_pos_limits[0], joint_pos_limits[1] - ) - - """ Helper kernel to reconstruct limits """ @@ -420,28 +108,3 @@ def derive_joint_acceleration_from_velocity( # update previous velocity previous_joint_velocity[env_index, joint_index] = joint_velocity[env_index, joint_index] - - -@wp.kernel -def clip_joint_array_with_limits_masked( - lower_limits: wp.array(dtype=wp.float32), - upper_limits: wp.array(dtype=wp.float32), - joint_array: wp.array(dtype=wp.float32), - env_mask: wp.array(dtype=wp.bool), - joint_mask: wp.array(dtype=wp.bool), -): - joint_index = wp.tid() - if env_mask[joint_index] and joint_mask[joint_index]: - joint_array[joint_index] = wp.clamp( - joint_array[joint_index], lower_limits[joint_index], upper_limits[joint_index] - ) - - -@wp.kernel -def clip_joint_array_with_limits( - lower_limits: wp.array(dtype=wp.float32), - upper_limits: wp.array(dtype=wp.float32), - joint_array: wp.array(dtype=wp.float32), -): - index = wp.tid() - joint_array[index] = wp.clamp(joint_array[index], lower_limits[index], upper_limits[index]) diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/other_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/other_kernels.py index 1c11f77c4a9..4e2ded9ca76 100644 --- a/source/isaaclab_newton/isaaclab_newton/kernels/other_kernels.py +++ b/source/isaaclab_newton/isaaclab_newton/kernels/other_kernels.py @@ -6,42 +6,20 @@ import warp as wp -@wp.kernel -def update_wrench_array( - new_value: wp.array2d(dtype=wp.spatial_vectorf), - wrench: wp.array2d(dtype=wp.spatial_vectorf), - env_ids: wp.array(dtype=wp.bool), - body_ids: wp.array(dtype=wp.bool), -): - env_index, body_index = wp.tid() - if env_ids[env_index] and body_ids[body_index]: - wrench[env_index, body_index] = new_value[env_index, body_index] - - -@wp.kernel -def update_wrench_array_with_value( - value: wp.spatial_vectorf, - wrench: wp.array2d(dtype=wp.spatial_vectorf), - env_ids: wp.array(dtype=wp.bool), - body_ids: wp.array(dtype=wp.bool), -): - env_index, body_index = wp.tid() - if env_ids[env_index] and body_ids[body_index]: - wrench[env_index, body_index] = value - - @wp.func def update_wrench_with_force( + wrench: wp.spatial_vectorf, force: wp.vec3f, ) -> wp.spatial_vectorf: - return wp.spatial_vectorf(0.0, 0.0, 0.0, force[0], force[1], force[2]) + return wp.spatial_vector(force, wp.spatial_bottom(wrench), wp.float32) @wp.func def update_wrench_with_torque( + wrench: wp.spatial_vectorf, torque: wp.vec3f, ) -> wp.spatial_vectorf: - return wp.spatial_vectorf(torque[0], torque[1], torque[2], 0.0, 0.0, 0.0) + return wp.spatial_vector(wp.spatial_top(wrench), torque, wp.float32) @wp.kernel @@ -53,7 +31,9 @@ def update_wrench_array_with_force( ): env_index, body_index = wp.tid() if env_ids[env_index] and body_ids[body_index]: - wrench[env_index, body_index] = update_wrench_with_force(forces[env_index, body_index]) + wrench[env_index, body_index] = update_wrench_with_force( + wrench[env_index, body_index], forces[env_index, body_index] + ) @wp.kernel @@ -65,7 +45,9 @@ def update_wrench_array_with_torque( ): env_index, body_index = wp.tid() if env_ids[env_index] and body_ids[body_index]: - wrench[env_index, body_index] = update_wrench_with_torque(torques[env_index, body_index]) + wrench[env_index, body_index] = update_wrench_with_torque( + wrench[env_index, body_index], torques[env_index, body_index] + ) @wp.kernel @@ -75,13 +57,3 @@ def generate_mask_from_ids( ): index = wp.tid() mask[ids[index]] = True - - -@wp.kernel -def populate_empty_array( - input_array: wp.array(dtype=wp.float32), - output_array: wp.array(dtype=wp.float32), - indices: wp.array(dtype=wp.int32), -): - index = wp.tid() - output_array[indices[index]] = input_array[index] diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py index 8e97803d7e7..be7181cfeb0 100644 --- a/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py +++ b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py @@ -68,22 +68,6 @@ def split_transform_batched_array_to_quaternion_batched_array( quaternion[index, body_index] = wp.transform_get_rotation(transform[index, body_index]) -@wp.kernel -def generate_pose_from_position_with_unit_quaternion( - position: wp.array(dtype=wp.vec3f), - pose: wp.array(dtype=wp.transformf), -): - """ - Generate a pose from a position with a unit quaternion. - - Args: - position: The position. Shape is (num_instances, 3). - pose: The pose. Shape is (num_instances, 7). (modified) - """ - index = wp.tid() - pose[index] = wp.transformf(position[index], wp.quatf(0.0, 0.0, 0.0, 1.0)) - - @wp.kernel def generate_pose_from_position_with_unit_quaternion_batched( position: wp.array2d(dtype=wp.vec3f), @@ -135,6 +119,20 @@ def split_state_to_velocity( return wp.spatial_vectorf(state[7], state[8], state[9], state[10], state[11], state[12]) +@wp.kernel +def split_state_to_pose_and_velocity( + state: wp.array(dtype=vec13f), + pose: wp.array(dtype=wp.transformf), + velocity: wp.array(dtype=wp.spatial_vectorf), +): + """ + Split a state into a pose and a velocity. + """ + index = wp.tid() + pose[index] = split_state_to_pose(state[index]) + velocity[index] = split_state_to_velocity(state[index]) + + @wp.func def combine_state( pose: wp.transformf, @@ -193,31 +191,6 @@ def combine_pose_and_velocity_to_state( root_state[env_index] = combine_state(root_pose[env_index], root_velocity[env_index]) -@wp.kernel -def combine_pose_and_velocity_to_state_masked( - root_pose: wp.array(dtype=wp.transformf), - root_velocity: wp.array(dtype=wp.spatial_vectorf), - root_state: wp.array(dtype=vec13f), - env_mask: wp.array(dtype=wp.bool), -): - """ - Combine a pose and a velocity into a state. - - The state is given in the following format: (x, y, z, qx, qy, qz, qw, wx, wy, wz, vx, vy, vz). - - .. note:: The quaternion is given in the following format: (qx, qy, qz, qw). - - Args: - pose: The pose. Shape is (num_instances, 7). - velocity: The velocity. Shape is (num_instances, 6). - state: The state. Shape is (num_instances, 13). (modified) - env_mask: The mask of the environments to combine the state for. Shape is (num_instances,). - """ - env_index = wp.tid() - if env_mask[env_index]: - root_state[env_index] = combine_state(root_pose[env_index], root_velocity[env_index]) - - @wp.kernel def combine_pose_and_velocity_to_state_batched( root_pose: wp.array2d(dtype=wp.transformf), @@ -242,35 +215,6 @@ def combine_pose_and_velocity_to_state_batched( ) -@wp.kernel -def combine_pose_and_velocity_to_state_batched_masked( - root_pose: wp.array2d(dtype=wp.transformf), - root_velocity: wp.array2d(dtype=wp.spatial_vectorf), - root_state: wp.array2d(dtype=vec13f), - env_mask: wp.array(dtype=wp.bool), - body_mask: wp.array(dtype=wp.bool), -): - """ - Combine a pose and a velocity into a state. - - The state is given in the following format: (x, y, z, qx, qy, qz, qw, wx, wy, wz, vx, vy, vz). - - .. note:: The quaternion is given in the following format: (qx, qy, qz, qw). - - Args: - pose: The pose. Shape is (num_instances, num_bodies, 7). - velocity: The velocity. Shape is (num_instances, num_bodies, 6). - state: The state. Shape is (num_instances, num_bodies, 13). (modified) - env_mask: The mask of the environments to combine the state for. Shape is (num_instances,). - body_mask: The mask of the bodies to combine the state for. Shape is (num_bodies,). - """ - env_index, body_index = wp.tid() - if env_mask[env_index] and body_mask[body_index]: - root_state[env_index, body_index] = combine_state( - root_pose[env_index, body_index], root_velocity[env_index, body_index] - ) - - """ Frame combination kernels """ @@ -293,29 +237,6 @@ def combine_transforms(p1: wp.vec3f, q1: wp.quatf, p2: wp.vec3f, q2: wp.quatf) - return wp.transformf(p1 + wp.quat_rotate(q1, p2), q1 * q2) -@wp.kernel -def combine_frame_transforms_partial( - pose_1: wp.array(dtype=wp.transformf), - position_2: wp.array(dtype=wp.vec3f), - resulting_pose: wp.array(dtype=wp.transformf), -): - """ - Combine a frame transform with a position. - - Args: - pose_1: The frame transform. Shape is (num_instances, 7). - position_2: The position. Shape is (num_instances, 3). - resulting_pose: The resulting pose. Shape is (num_instances, 7). (modified) - """ - index = wp.tid() - resulting_pose[index] = combine_transforms( - wp.transform_get_translation(pose_1[index]), - wp.transform_get_rotation(pose_1[index]), - position_2[index], - wp.quatf(0.0, 0.0, 0.0, 1.0), - ) - - @wp.kernel def combine_frame_transforms_partial_root( pose_1: wp.array(dtype=wp.transformf), @@ -362,52 +283,6 @@ def combine_frame_transforms_partial_batch( ) -@wp.kernel -def combine_frame_transforms( - pose_1: wp.array(dtype=wp.transformf), - pose_2: wp.array(dtype=wp.transformf), - resulting_pose: wp.array(dtype=wp.transformf), -): - """ - Combine two transforms. - - Args: - pose_1: The first transform. Shape is (1, 7). - pose_2: The second transform. Shape is (1, 7). - resulting_pose: The resulting pose. Shape is (1, 7). (modified) - """ - index = wp.tid() - resulting_pose[index] = combine_transforms( - wp.transform_get_translation(pose_1[index]), - wp.transform_get_rotation(pose_1[index]), - wp.transform_get_translation(pose_2[index]), - wp.transform_get_rotation(pose_2[index]), - ) - - -@wp.kernel -def combine_frame_transforms_batch( - pose_1: wp.array2d(dtype=wp.transformf), - pose_2: wp.array2d(dtype=wp.transformf), - resulting_pose: wp.array2d(dtype=wp.transformf), -): - """ - Combine two transforms. - - Args: - pose_1: The first transform. Shape is (num_instances, 7). - pose_2: The second transform. Shape is (num_instances, 7). - resulting_pose: The resulting pose. Shape is (num_instances, 7). (modified) - """ - env_idx, body_idx = wp.tid() - resulting_pose[env_idx, body_idx] = combine_transforms( - wp.transform_get_translation(pose_1[env_idx, body_idx]), - wp.transform_get_rotation(pose_1[env_idx, body_idx]), - wp.transform_get_translation(pose_2[env_idx, body_idx]), - wp.transform_get_rotation(pose_2[env_idx, body_idx]), - ) - - @wp.kernel def project_vec_from_pose_single( vec: wp.vec3f, pose: wp.array(dtype=wp.transformf), resulting_vec: wp.array(dtype=wp.vec3f) @@ -481,120 +356,11 @@ def compute_heading( heading[index] = heading_vec_b(wp.transform_get_rotation(pose_w[index]), forward_vec_b) -""" -Update kernels -""" - - -@wp.kernel -def update_transforms_array( - new_pose: wp.array(dtype=wp.transformf), - pose: wp.array(dtype=wp.transformf), - env_mask: wp.array(dtype=wp.bool), -): - """ - Update a transforms array. - - Args: - new_pose: The new pose. Shape is (num_instances, 7). - pose: The pose. Shape is (num_instances, 7). (modified) - env_mask: The mask of the environments to update the pose for. Shape is (num_instances,). - """ - index = wp.tid() - if env_mask[index]: - pose[index] = new_pose[index] - - -@wp.kernel -def update_transforms_array_with_value( - value: wp.transformf, - pose: wp.array(dtype=wp.transformf), - env_mask: wp.array(dtype=wp.bool), -): - """ - Update a transforms array with a value. - - Args: - value: The value. Shape is (7,). - pose: The pose. Shape is (num_instances, 7). (modified) - env_mask: The mask of the environments to update the pose for. Shape is (num_instances,). - """ - index = wp.tid() - if env_mask[index]: - pose[index] = value - - -@wp.kernel -def update_spatial_vector_array( - velocity: wp.array(dtype=wp.spatial_vectorf), - new_velocity: wp.array(dtype=wp.spatial_vectorf), - env_mask: wp.array(dtype=wp.bool), -): - """ - Update a spatial vector array. - - Args: - new_velocity: The new velocity. Shape is (num_instances, 6). - velocity: The velocity. Shape is (num_instances, 6). (modified) - env_mask: The mask of the environments to update the velocity for. Shape is (num_instances,). - """ - index = wp.tid() - if env_mask[index]: - velocity[index] = new_velocity[index] - - -@wp.kernel -def update_spatial_vector_array_with_value( - value: wp.spatial_vectorf, - velocity: wp.array(dtype=wp.spatial_vectorf), - env_mask: wp.array(dtype=wp.bool), -): - """ - Update a spatial vector array with a value. - - Args: - value: The value. Shape is (6,). - velocity: The velocity. Shape is (num_instances, 6). (modified) - env_mask: The mask of the environments to update the velocity for. Shape is (num_instances,). - """ - index = wp.tid() - if env_mask[index]: - velocity[index] = value - - """ Transform kernels """ -@wp.kernel -def transform_CoM_pose_to_link_frame_masked( - com_pose_w: wp.array(dtype=wp.transformf), - com_pose_link_frame: wp.array(dtype=wp.transformf), - link_pose_w: wp.array(dtype=wp.transformf), - env_mask: wp.array(dtype=wp.bool), -): - """ - Transform a CoM pose to a link frame. - - - - Args: - com_pose_w: The CoM pose in the world frame. Shape is (num_instances, 7). - com_pose_link_frame: The CoM pose in the link frame. Shape is (num_instances, 7). - link_pose_w: The link pose in the world frame. Shape is (num_instances, 7). (modified) - env_mask: The mask of the environments to transform the CoM pose to the link frame for. Shape is (num_instances,). - """ - index = wp.tid() - if env_mask[index]: - link_pose_w[index] = combine_transforms( - wp.transform_get_translation(com_pose_w[index]), - wp.transform_get_rotation(com_pose_w[index]), - wp.transform_get_translation(com_pose_link_frame[index]), - wp.quatf(0.0, 0.0, 0.0, 1.0), - ) - - @wp.kernel def transform_CoM_pose_to_link_frame_masked_root( com_pose_w: wp.array(dtype=wp.transformf), @@ -618,6 +384,6 @@ def transform_CoM_pose_to_link_frame_masked_root( link_pose_w[index] = combine_transforms( wp.transform_get_translation(com_pose_w[index]), wp.transform_get_rotation(com_pose_w[index]), - com_pos_link_frame[index][0], + -com_pos_link_frame[index][0], wp.quatf(0.0, 0.0, 0.0, 1.0), ) diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/velocity_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/velocity_kernels.py index f9638e5f9d7..1696f8a7594 100644 --- a/source/isaaclab_newton/isaaclab_newton/kernels/velocity_kernels.py +++ b/source/isaaclab_newton/isaaclab_newton/kernels/velocity_kernels.py @@ -96,13 +96,12 @@ def velocity_projector( Returns: wp.spatial_vectorf: The projected velocity in the link frame. Shape is (6,). """ - u = wp.spatial_top(com_velocity) - w = wp.spatial_bottom(com_velocity) + wp.cross( + u = wp.spatial_top(com_velocity) + wp.cross( wp.spatial_bottom(com_velocity), wp.quat_rotate(wp.transform_get_rotation(link_pose), -com_position), ) - return wp.spatial_vectorf(u[0], u[1], u[2], w[0], w[1], w[2]) - # return wp.spatial_vector(u, w) --> Do it like that. + w = wp.spatial_bottom(com_velocity) + return wp.spatial_vector(u, w, dtype=wp.float32) @wp.func @@ -129,12 +128,12 @@ def velocity_projector_inv( Returns: wp.spatial_vectorf: The projected velocity in the com frame. Shape is (6,). """ - u = wp.spatial_top(com_velocity) - w = wp.spatial_bottom(com_velocity) + wp.cross( + u = wp.spatial_top(com_velocity) + wp.cross( wp.spatial_bottom(com_velocity), wp.quat_rotate(wp.transform_get_rotation(link_pose), com_position), ) - return wp.spatial_vectorf(u[0], u[1], u[2], w[0], w[1], w[2]) + w = wp.spatial_bottom(com_velocity) + return wp.spatial_vector(u, w, dtype=wp.float32) """ @@ -142,67 +141,6 @@ def velocity_projector_inv( """ -@wp.kernel -def project_com_velocity_to_link_frame( - com_velocity: wp.array(dtype=wp.spatial_vectorf), - link_pose: wp.array(dtype=wp.transformf), - com_position: wp.array(dtype=wp.vec3f), - link_velocity: wp.array(dtype=wp.spatial_vectorf), -): - """ - Project a velocity from the com frame to the link frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - link_velocity: The link velocity. Shape is (num_links, 6). (modified) - """ - index = wp.tid() - link_velocity[index] = velocity_projector(com_velocity[index], link_pose[index], com_position[index]) - - -@wp.kernel -def project_com_velocity_to_link_frame_masked( - com_velocity: wp.array(dtype=wp.spatial_vectorf), - link_pose: wp.array(dtype=wp.transformf), - com_position: wp.array(dtype=wp.vec3f), - link_velocity: wp.array(dtype=wp.spatial_vectorf), - mask: wp.array(dtype=wp.bool), -): - """ - Project a velocity from the com frame to the link frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - link_velocity: The link velocity in the world frame. Shape is (num_links, 6). (modified) - mask: The mask of the links to project the velocity to. Shape is (num_links,). - """ - index = wp.tid() - if mask[index]: - link_velocity[index] = velocity_projector( - com_velocity[index], - link_pose[index], - com_position[index], - ) - - @wp.kernel def project_com_velocity_to_link_frame_batch( com_velocity: wp.array2d(dtype=wp.spatial_vectorf), @@ -234,42 +172,6 @@ def project_com_velocity_to_link_frame_batch( ) -@wp.kernel -def project_com_velocity_to_link_frame_batch_masked( - com_velocity: wp.array2d(dtype=wp.spatial_vectorf), - link_pose: wp.array2d(dtype=wp.transformf), - com_position: wp.array2d(dtype=wp.vec3f), - link_velocity: wp.array2d(dtype=wp.spatial_vectorf), - env_mask: wp.array(dtype=wp.bool), - body_mask: wp.array(dtype=wp.bool), -): - """ - Project a velocity from the com frame to the link frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - link_velocity: The link velocity in the world frame. Shape is (num_links, 6). (modified) - env_mask: The mask of the environments to project the velocity to. Shape is (num_links,). - body_mask: The mask of the bodies to project the velocity to. Shape is (num_links,). - """ - env_idx, body_idx = wp.tid() - if env_mask[env_idx] and body_mask[body_idx]: - link_velocity[env_idx, body_idx] = velocity_projector( - com_velocity[env_idx, body_idx], - link_pose[env_idx, body_idx], - com_position[env_idx, body_idx], - ) - - @wp.kernel def project_com_velocity_to_link_frame_root( com_velocity: wp.array(dtype=wp.spatial_vectorf), @@ -297,67 +199,6 @@ def project_com_velocity_to_link_frame_root( link_velocity[index] = velocity_projector(com_velocity[index], link_pose[index], com_position[index][0]) -@wp.kernel -def project_link_velocity_to_com_frame( - link_velocity: wp.array(dtype=wp.spatial_vectorf), - link_pose: wp.array(dtype=wp.transformf), - com_position: wp.array(dtype=wp.vec3f), - com_velocity: wp.array(dtype=wp.spatial_vectorf), -): - """ - Project a velocity from the link frame to the com frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - link_velocity: The link velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). (modified) - """ - index = wp.tid() - com_velocity[index] = velocity_projector_inv(link_velocity[index], link_pose[index], com_position[index]) - - -@wp.kernel -def project_link_velocity_to_com_frame_masked( - link_velocity: wp.array(dtype=wp.spatial_vectorf), - link_pose: wp.array(dtype=wp.transformf), - com_position: wp.array(dtype=wp.vec3f), - com_velocity: wp.array(dtype=wp.spatial_vectorf), - mask: wp.array(dtype=wp.bool), -): - """ - Project a velocity from the link frame to the com frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - link_velocity: The link velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). (modified) - mask: The mask of the links to project the velocity to. Shape is (num_links,). - """ - index = wp.tid() - if mask[index]: - com_velocity[index] = velocity_projector_inv( - link_velocity[index], - link_pose[index], - com_position[index], - ) - - @wp.kernel def project_link_velocity_to_com_frame_masked_root( link_velocity: wp.array(dtype=wp.spatial_vectorf), @@ -392,149 +233,11 @@ def project_link_velocity_to_com_frame_masked_root( ) -@wp.kernel -def project_link_velocity_to_com_frame_batch( - link_velocity: wp.array2d(dtype=wp.spatial_vectorf), - link_pose: wp.array2d(dtype=wp.transformf), - com_position: wp.array2d(dtype=wp.vec3f), - com_velocity: wp.array2d(dtype=wp.spatial_vectorf), -): - """ - Project a velocity from the link frame to the com frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - link_velocity (wp.array2d(dtype=wp.spatial_vectorf)): The link velocity in the world frame. - link_pose (wp.array2d(dtype=wp.transformf)): The link pose in the world frame. - com_position (wp.array2d(dtype=wp.vec3f)): The com position in link frame. - com_velocity (wp.array2d(dtype=wp.spatial_vectorf)): The com velocity in the world frame. (destination) - """ - env_idx, body_idx = wp.tid() - com_velocity[env_idx, body_idx] = velocity_projector_inv( - link_velocity[env_idx, body_idx], link_pose[env_idx, body_idx], com_position[env_idx, body_idx] - ) - - -@wp.kernel -def project_link_velocity_to_com_frame_batch_masked( - link_velocity: wp.array2d(dtype=wp.spatial_vectorf), - link_pose: wp.array2d(dtype=wp.transformf), - com_position: wp.array2d(dtype=wp.vec3f), - com_velocity: wp.array2d(dtype=wp.spatial_vectorf), - env_mask: wp.array(dtype=wp.bool), - body_mask: wp.array(dtype=wp.bool), -): - """ - Project a velocity from the link frame to the com frame. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - .. note:: Only :arg:`com_position` is needed as in Newton, the CoM orientation is always aligned with the - link frame. - - Args: - link_velocity: The link velocity in the world frame. Shape is (num_links, 6). - link_pose: The link pose in the world frame. Shape is (num_links, 7). - com_position: The com position in link frame. Shape is (num_links, 3). - com_velocity: The com velocity in the world frame. Shape is (num_links, 6). (modified) - env_mask: The mask of the environments to project the velocity to. Shape is (num_links,). - body_mask: The mask of the bodies to project the velocity to. Shape is (num_links,). - """ - env_idx, body_idx = wp.tid() - if env_mask[env_idx] and body_mask[body_idx]: - com_velocity[env_idx, body_idx] = velocity_projector_inv( - link_velocity[env_idx, body_idx], - link_pose[env_idx, body_idx], - com_position[env_idx, body_idx], - ) - - -""" -Kernels to update spatial vector arrays -""" - - -@wp.kernel -def update_spatial_vector_array_masked( - new_velocity: wp.array(dtype=wp.spatial_vectorf), - velocity: wp.array(dtype=wp.spatial_vectorf), - mask: wp.array(dtype=wp.bool), -): - """ - Update a velocity array with a new velocity. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - Args: - new_velocity: The new velocity. Shape is (num_links, 6). - velocity: The velocity array. Shape is (num_links, 6). (modified) - mask: The mask of the velocities to update. Shape is (num_links,). - """ - index = wp.tid() - if mask[index]: - velocity[index] = new_velocity[index] - - -@wp.kernel -def update_spatial_vector_array_batch_masked( - new_velocity: wp.array2d(dtype=wp.spatial_vectorf), - velocity: wp.array2d(dtype=wp.spatial_vectorf), - env_mask: wp.array(dtype=wp.bool), - body_mask: wp.array(dtype=wp.bool), -): - """ - Update a velocity array with a new velocity. - - Velocities are given in the following format: (wx, wy, wz, vx, vy, vz). - - .. caution:: Velocities are given with angular velocity first and linear velocity second. - - Args: - new_velocity: The new velocity. Shape is (num_links, 6). - velocity: The velocity array. Shape is (num_links, 6). (modified) - env_mask: The mask of the environments to update. Shape is (num_links,). - body_mask: The mask of the bodies to update. Shape is (num_links,). - """ - env_idx, body_idx = wp.tid() - if env_mask[env_idx] and body_mask[body_idx]: - velocity[env_idx, body_idx] = new_velocity[env_idx, body_idx] - - """ Kernels to derive body acceleration from velocity. """ -@wp.kernel -def derive_body_acceleration_from_velocity( - velocity: wp.array(dtype=wp.spatial_vectorf), - previous_velocity: wp.array(dtype=wp.spatial_vectorf), - dt: float, - acceleration: wp.array(dtype=wp.spatial_vectorf), -): - """ - Derive the body acceleration from the velocity. - - Args: - velocity: The velocity. Shape is (num_instances, 6). - previous_velocity: The previous velocity. Shape is (num_instances, 6). - dt: The time step. - acceleration: The acceleration. Shape is (num_instances, 6). (modified) - """ - env_idx = wp.tid() - acceleration[env_idx] = (velocity[env_idx] - previous_velocity[env_idx]) / dt - - @wp.kernel def derive_body_acceleration_from_velocity_batched( velocity: wp.array2d(dtype=wp.spatial_vectorf), diff --git a/source/isaaclab_newton/test/assets/articulation/__init__.py b/source/isaaclab_newton/test/assets/articulation/__init__.py new file mode 100644 index 00000000000..2e924fbf1b1 --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause diff --git a/source/isaaclab_newton/test/assets/articulation/benchmark_articulation.py b/source/isaaclab_newton/test/assets/articulation/benchmark_articulation.py new file mode 100644 index 00000000000..b256ec594db --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/benchmark_articulation.py @@ -0,0 +1,1395 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Micro-benchmarking framework for Articulation class. + +This module provides a benchmarking framework to measure the performance of setter and writer +methods in the Articulation class. Each method is benchmarked under two scenarios: + +1. **Best Case (Warp)**: Inputs are Warp arrays with masks - this is the optimal path that + avoids any data conversion overhead. + +2. **Worst Case (Torch)**: Inputs are PyTorch tensors with indices - this path requires + conversion from Torch to Warp and from indices to masks. + +Usage: + python benchmark_articulation.py [--num_iterations N] [--warmup_steps W] [--num_instances I] [--num_bodies B] [--num_joints J] + +Example: + python benchmark_articulation.py --num_iterations 1000 --warmup_steps 10 + python benchmark_articulation.py --mode warp # Only run Warp benchmarks + python benchmark_articulation.py --mode torch # Only run Torch benchmarks +""" + +from __future__ import annotations + +import argparse +import contextlib +import numpy as np +import time +import torch +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from unittest.mock import MagicMock, patch + +import warp as wp +from isaaclab_newton.assets.articulation.articulation import Articulation +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData +from isaaclab_newton.kernels import vec13f + +# Import mock classes from shared module +from mock_interface import MockNewtonArticulationView, MockNewtonModel + +from isaaclab.assets.articulation.articulation_cfg import ArticulationCfg + +# Initialize Warp +wp.init() + +# Suppress deprecation warnings during benchmarking +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) + + +class InputMode(Enum): + """Input mode for benchmarks.""" + + WARP = "warp" + TORCH = "torch" + + +def get_git_info() -> dict: + """Get git repository information. + + Returns: + Dictionary containing git commit hash, branch, and other info. + """ + import os + import subprocess + + git_info = { + "commit_hash": "Unknown", + "commit_hash_short": "Unknown", + "branch": "Unknown", + "commit_date": "Unknown", + } + + script_dir = os.path.dirname(os.path.abspath(__file__)) + + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["commit_hash"] = result.stdout.strip() + git_info["commit_hash_short"] = result.stdout.strip()[:8] + + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["branch"] = result.stdout.strip() + + result = subprocess.run( + ["git", "log", "-1", "--format=%ci"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["commit_date"] = result.stdout.strip() + + except Exception: + pass + + return git_info + + +def get_hardware_info() -> dict: + """Gather hardware information for the benchmark. + + Returns: + Dictionary containing CPU, GPU, and memory information. + """ + import os + import platform + + hardware_info = { + "cpu": { + "name": platform.processor() or "Unknown", + "physical_cores": os.cpu_count(), + }, + "gpu": {}, + "memory": {}, + "system": { + "platform": platform.system(), + "platform_release": platform.release(), + "platform_version": platform.version(), + "architecture": platform.machine(), + "python_version": platform.python_version(), + }, + } + + # Try to get more detailed CPU info on Linux + with contextlib.suppress(Exception): + with open("/proc/cpuinfo") as f: + cpuinfo = f.read() + for line in cpuinfo.split("\n"): + if "model name" in line: + hardware_info["cpu"]["name"] = line.split(":")[1].strip() + break + + # Memory info + try: + with open("/proc/meminfo") as f: + meminfo = f.read() + for line in meminfo.split("\n"): + if "MemTotal" in line: + mem_kb = int(line.split()[1]) + hardware_info["memory"]["total_gb"] = round(mem_kb / (1024 * 1024), 2) + break + except Exception: + try: + import psutil + + mem = psutil.virtual_memory() + hardware_info["memory"]["total_gb"] = round(mem.total / (1024**3), 2) + except ImportError: + hardware_info["memory"]["total_gb"] = "Unknown" + + # GPU info using PyTorch + if torch.cuda.is_available(): + hardware_info["gpu"]["available"] = True + hardware_info["gpu"]["count"] = torch.cuda.device_count() + hardware_info["gpu"]["devices"] = [] + + for i in range(torch.cuda.device_count()): + gpu_props = torch.cuda.get_device_properties(i) + hardware_info["gpu"]["devices"].append({ + "index": i, + "name": gpu_props.name, + "total_memory_gb": round(gpu_props.total_memory / (1024**3), 2), + "compute_capability": f"{gpu_props.major}.{gpu_props.minor}", + "multi_processor_count": gpu_props.multi_processor_count, + }) + + current_device = torch.cuda.current_device() + hardware_info["gpu"]["current_device"] = current_device + hardware_info["gpu"]["current_device_name"] = torch.cuda.get_device_name(current_device) + else: + hardware_info["gpu"]["available"] = False + + hardware_info["gpu"]["pytorch_version"] = torch.__version__ + if torch.cuda.is_available(): + try: + import torch.version as torch_version + + cuda_version = getattr(torch_version, "cuda", None) + hardware_info["gpu"]["cuda_version"] = cuda_version if cuda_version else "Unknown" + except Exception: + hardware_info["gpu"]["cuda_version"] = "Unknown" + else: + hardware_info["gpu"]["cuda_version"] = "N/A" + + try: + warp_version = getattr(wp.config, "version", None) + hardware_info["warp"] = {"version": warp_version if warp_version else "Unknown"} + except Exception: + hardware_info["warp"] = {"version": "Unknown"} + + return hardware_info + + +def print_hardware_info(hardware_info: dict): + """Print hardware information to console.""" + print("\n" + "=" * 80) + print("HARDWARE INFORMATION") + print("=" * 80) + + sys_info = hardware_info.get("system", {}) + print(f"\nSystem: {sys_info.get('platform', 'Unknown')} {sys_info.get('platform_release', '')}") + print(f"Python: {sys_info.get('python_version', 'Unknown')}") + + cpu_info = hardware_info.get("cpu", {}) + print(f"\nCPU: {cpu_info.get('name', 'Unknown')}") + print(f" Cores: {cpu_info.get('physical_cores', 'Unknown')}") + + mem_info = hardware_info.get("memory", {}) + print(f"\nMemory: {mem_info.get('total_gb', 'Unknown')} GB") + + gpu_info = hardware_info.get("gpu", {}) + if gpu_info.get("available", False): + print(f"\nGPU: {gpu_info.get('current_device_name', 'Unknown')}") + for device in gpu_info.get("devices", []): + print(f" [{device['index']}] {device['name']}") + print(f" Memory: {device['total_memory_gb']} GB") + print(f" Compute Capability: {device['compute_capability']}") + print(f" SM Count: {device['multi_processor_count']}") + print(f"\nPyTorch: {gpu_info.get('pytorch_version', 'Unknown')}") + print(f"CUDA: {gpu_info.get('cuda_version', 'Unknown')}") + else: + print("\nGPU: Not available") + + warp_info = hardware_info.get("warp", {}) + print(f"Warp: {warp_info.get('version', 'Unknown')}") + + repo_info = get_git_info() + print("\nRepository:") + print(f" Commit: {repo_info.get('commit_hash_short', 'Unknown')}") + print(f" Branch: {repo_info.get('branch', 'Unknown')}") + print(f" Date: {repo_info.get('commit_date', 'Unknown')}") + print("=" * 80) + + +@dataclass +class BenchmarkConfig: + """Configuration for the benchmarking framework.""" + + num_iterations: int = 1000 + """Number of iterations to run each function.""" + + warmup_steps: int = 10 + """Number of warmup steps before timing.""" + + num_instances: int = 4096 + """Number of articulation instances.""" + + num_bodies: int = 12 + """Number of bodies per articulation.""" + + num_joints: int = 11 + """Number of joints per articulation.""" + + device: str = "cuda:0" + """Device to run benchmarks on.""" + + mode: str = "both" + """Benchmark mode: 'warp', 'torch', or 'both'.""" + + +@dataclass +class BenchmarkResult: + """Result of a single benchmark.""" + + name: str + """Name of the benchmarked method.""" + + mode: InputMode + """Input mode used (WARP or TORCH).""" + + mean_time_us: float + """Mean execution time in microseconds.""" + + std_time_us: float + """Standard deviation of execution time in microseconds.""" + + num_iterations: int + """Number of iterations run.""" + + skipped: bool = False + """Whether the benchmark was skipped.""" + + skip_reason: str = "" + """Reason for skipping the benchmark.""" + + +@dataclass +class MethodBenchmark: + """Definition of a method to benchmark.""" + + name: str + """Name of the method.""" + + method_name: str + """Actual method name on the Articulation class.""" + + input_generator_warp: Callable + """Function to generate Warp inputs.""" + + input_generator_torch: Callable + """Function to generate Torch inputs.""" + + category: str = "general" + """Category of the method (e.g., 'root_state', 'joint_state', 'joint_params').""" + + +def create_test_articulation( + num_instances: int = 2, + num_joints: int = 6, + num_bodies: int = 7, + device: str = "cuda:0", +) -> tuple[Articulation, MockNewtonArticulationView, MagicMock]: + """Create a test Articulation instance with mocked dependencies.""" + joint_names = [f"joint_{i}" for i in range(num_joints)] + body_names = [f"body_{i}" for i in range(num_bodies)] + + articulation = object.__new__(Articulation) + + articulation.cfg = ArticulationCfg( + prim_path="/World/Robot", + soft_joint_pos_limit_factor=1.0, + actuators={}, + ) + + mock_view = MockNewtonArticulationView( + num_instances=num_instances, + num_bodies=num_bodies, + num_joints=num_joints, + device=device, + is_fixed_base=False, + joint_names=joint_names, + body_names=body_names, + ) + mock_view.set_mock_data() + + object.__setattr__(articulation, "_root_view", mock_view) + object.__setattr__(articulation, "_device", device) + + mock_newton_manager = MagicMock() + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + mock_newton_manager.get_model.return_value = mock_model + mock_newton_manager.get_state_0.return_value = mock_state + mock_newton_manager.get_control.return_value = mock_control + mock_newton_manager.get_dt.return_value = 0.01 + + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager", mock_newton_manager): + data = ArticulationData(mock_view, device) + object.__setattr__(articulation, "_data", data) + + return articulation, mock_view, mock_newton_manager + + +# ============================================================================= +# Input Generators +# ============================================================================= + + +def make_warp_env_mask(num_instances: int, device: str) -> wp.array: + """Create an all-true environment mask.""" + return wp.ones((num_instances,), dtype=wp.bool, device=device) + + +def make_warp_joint_mask(num_joints: int, device: str) -> wp.array: + """Create an all-true joint mask.""" + return wp.ones((num_joints,), dtype=wp.bool, device=device) + + +def make_warp_body_mask(num_bodies: int, device: str) -> wp.array: + """Create an all-true body mask.""" + return wp.ones((num_bodies,), dtype=wp.bool, device=device) + + +# --- Root Link Pose --- +def gen_root_link_pose_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_link_pose_to_sim.""" + return { + "pose": wp.from_torch( + torch.rand(config.num_instances, 7, device=config.device, dtype=torch.float32), + dtype=wp.transformf, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_link_pose_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_link_pose_to_sim.""" + return { + "pose": torch.rand(config.num_instances, 7, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root COM Pose --- +def gen_root_com_pose_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_com_pose_to_sim.""" + return { + "root_pose": wp.from_torch( + torch.rand(config.num_instances, 7, device=config.device, dtype=torch.float32), + dtype=wp.transformf, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_com_pose_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_com_pose_to_sim.""" + return { + "root_pose": torch.rand(config.num_instances, 7, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root Link Velocity --- +def gen_root_link_velocity_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_link_velocity_to_sim.""" + return { + "root_velocity": wp.from_torch( + torch.rand(config.num_instances, 6, device=config.device, dtype=torch.float32), + dtype=wp.spatial_vectorf, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_link_velocity_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_link_velocity_to_sim.""" + return { + "root_velocity": torch.rand(config.num_instances, 6, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root COM Velocity --- +def gen_root_com_velocity_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_com_velocity_to_sim.""" + return { + "root_velocity": wp.from_torch( + torch.rand(config.num_instances, 6, device=config.device, dtype=torch.float32), + dtype=wp.spatial_vectorf, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_com_velocity_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_com_velocity_to_sim.""" + return { + "root_velocity": torch.rand(config.num_instances, 6, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root State (Deprecated) --- +def gen_root_state_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_state_to_sim.""" + return { + "root_state": wp.from_torch( + torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + dtype=vec13f, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_state_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_state_to_sim.""" + return { + "root_state": torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root COM State (Deprecated) --- +def gen_root_com_state_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_com_state_to_sim.""" + return { + "root_state": wp.from_torch( + torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + dtype=vec13f, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_com_state_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_com_state_to_sim.""" + return { + "root_state": torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Root Link State (Deprecated) --- +def gen_root_link_state_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_root_link_state_to_sim.""" + return { + "root_state": wp.from_torch( + torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + dtype=vec13f, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + } + + +def gen_root_link_state_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_root_link_state_to_sim.""" + return { + "root_state": torch.rand(config.num_instances, 13, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + } + + +# --- Joint State --- +def gen_joint_state_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_state_to_sim.""" + return { + "position": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "velocity": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_state_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_state_to_sim.""" + return { + "position": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "velocity": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Position --- +def gen_joint_position_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_position_to_sim.""" + return { + "position": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_position_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_position_to_sim.""" + return { + "position": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Velocity --- +def gen_joint_velocity_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_velocity_to_sim.""" + return { + "velocity": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_velocity_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_velocity_to_sim.""" + return { + "velocity": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Stiffness --- +def gen_joint_stiffness_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_stiffness_to_sim.""" + return { + "stiffness": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_stiffness_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_stiffness_to_sim.""" + return { + "stiffness": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Damping --- +def gen_joint_damping_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_damping_to_sim.""" + return { + "damping": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_damping_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_damping_to_sim.""" + return { + "damping": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Position Limit --- +def gen_joint_position_limit_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_position_limit_to_sim.""" + return { + "lower_limits": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * -3.14, + dtype=wp.float32, + ), + "upper_limits": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 3.14, + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_position_limit_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_position_limit_to_sim.""" + return { + "lower_limits": ( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * -3.14 + ), + "upper_limits": ( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 3.14 + ), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Velocity Limit --- +def gen_joint_velocity_limit_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_velocity_limit_to_sim.""" + return { + "limits": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 10.0, + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_velocity_limit_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_velocity_limit_to_sim.""" + return { + "limits": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 10.0, + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Effort Limit --- +def gen_joint_effort_limit_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_effort_limit_to_sim.""" + return { + "limits": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 100.0, + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_effort_limit_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_effort_limit_to_sim.""" + return { + "limits": ( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 100.0 + ), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Armature --- +def gen_joint_armature_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_armature_to_sim.""" + return { + "armature": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 0.1, + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_armature_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_armature_to_sim.""" + return { + "armature": ( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 0.1 + ), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Joint Friction Coefficient --- +def gen_joint_friction_coefficient_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for write_joint_friction_coefficient_to_sim.""" + return { + "joint_friction_coeff": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 0.5, + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_joint_friction_coefficient_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for write_joint_friction_coefficient_to_sim.""" + return { + "joint_friction_coeff": ( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32) * 0.5 + ), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Set Joint Position Target --- +def gen_set_joint_position_target_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for set_joint_position_target.""" + return { + "target": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_set_joint_position_target_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for set_joint_position_target.""" + return { + "target": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Set Joint Velocity Target --- +def gen_set_joint_velocity_target_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for set_joint_velocity_target.""" + return { + "target": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_set_joint_velocity_target_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for set_joint_velocity_target.""" + return { + "target": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# --- Set Joint Effort Target --- +def gen_set_joint_effort_target_warp(config: BenchmarkConfig) -> dict: + """Generate Warp inputs for set_joint_effort_target.""" + return { + "target": wp.from_torch( + torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + dtype=wp.float32, + ), + "env_mask": make_warp_env_mask(config.num_instances, config.device), + "joint_mask": make_warp_joint_mask(config.num_joints, config.device), + } + + +def gen_set_joint_effort_target_torch(config: BenchmarkConfig) -> dict: + """Generate Torch inputs for set_joint_effort_target.""" + return { + "target": torch.rand(config.num_instances, config.num_joints, device=config.device, dtype=torch.float32), + "env_ids": list(range(config.num_instances)), + "joint_ids": list(range(config.num_joints)), + } + + +# ============================================================================= +# Method Benchmark Definitions +# ============================================================================= + +METHOD_BENCHMARKS = [ + # Root State Writers + MethodBenchmark( + name="write_root_link_pose_to_sim", + method_name="write_root_link_pose_to_sim", + input_generator_warp=gen_root_link_pose_warp, + input_generator_torch=gen_root_link_pose_torch, + category="root_state", + ), + MethodBenchmark( + name="write_root_com_pose_to_sim", + method_name="write_root_com_pose_to_sim", + input_generator_warp=gen_root_com_pose_warp, + input_generator_torch=gen_root_com_pose_torch, + category="root_state", + ), + MethodBenchmark( + name="write_root_link_velocity_to_sim", + method_name="write_root_link_velocity_to_sim", + input_generator_warp=gen_root_link_velocity_warp, + input_generator_torch=gen_root_link_velocity_torch, + category="root_state", + ), + MethodBenchmark( + name="write_root_com_velocity_to_sim", + method_name="write_root_com_velocity_to_sim", + input_generator_warp=gen_root_com_velocity_warp, + input_generator_torch=gen_root_com_velocity_torch, + category="root_state", + ), + # Root State Writers (Deprecated) + MethodBenchmark( + name="write_root_state_to_sim (deprecated)", + method_name="write_root_state_to_sim", + input_generator_warp=gen_root_state_warp, + input_generator_torch=gen_root_state_torch, + category="root_state_deprecated", + ), + MethodBenchmark( + name="write_root_com_state_to_sim (deprecated)", + method_name="write_root_com_state_to_sim", + input_generator_warp=gen_root_com_state_warp, + input_generator_torch=gen_root_com_state_torch, + category="root_state_deprecated", + ), + MethodBenchmark( + name="write_root_link_state_to_sim (deprecated)", + method_name="write_root_link_state_to_sim", + input_generator_warp=gen_root_link_state_warp, + input_generator_torch=gen_root_link_state_torch, + category="root_state_deprecated", + ), + # Joint State Writers + MethodBenchmark( + name="write_joint_state_to_sim", + method_name="write_joint_state_to_sim", + input_generator_warp=gen_joint_state_warp, + input_generator_torch=gen_joint_state_torch, + category="joint_state", + ), + MethodBenchmark( + name="write_joint_position_to_sim", + method_name="write_joint_position_to_sim", + input_generator_warp=gen_joint_position_warp, + input_generator_torch=gen_joint_position_torch, + category="joint_state", + ), + MethodBenchmark( + name="write_joint_velocity_to_sim", + method_name="write_joint_velocity_to_sim", + input_generator_warp=gen_joint_velocity_warp, + input_generator_torch=gen_joint_velocity_torch, + category="joint_state", + ), + # Joint Parameter Writers + MethodBenchmark( + name="write_joint_stiffness_to_sim", + method_name="write_joint_stiffness_to_sim", + input_generator_warp=gen_joint_stiffness_warp, + input_generator_torch=gen_joint_stiffness_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_damping_to_sim", + method_name="write_joint_damping_to_sim", + input_generator_warp=gen_joint_damping_warp, + input_generator_torch=gen_joint_damping_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_position_limit_to_sim", + method_name="write_joint_position_limit_to_sim", + input_generator_warp=gen_joint_position_limit_warp, + input_generator_torch=gen_joint_position_limit_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_velocity_limit_to_sim", + method_name="write_joint_velocity_limit_to_sim", + input_generator_warp=gen_joint_velocity_limit_warp, + input_generator_torch=gen_joint_velocity_limit_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_effort_limit_to_sim", + method_name="write_joint_effort_limit_to_sim", + input_generator_warp=gen_joint_effort_limit_warp, + input_generator_torch=gen_joint_effort_limit_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_armature_to_sim", + method_name="write_joint_armature_to_sim", + input_generator_warp=gen_joint_armature_warp, + input_generator_torch=gen_joint_armature_torch, + category="joint_params", + ), + MethodBenchmark( + name="write_joint_friction_coefficient_to_sim", + method_name="write_joint_friction_coefficient_to_sim", + input_generator_warp=gen_joint_friction_coefficient_warp, + input_generator_torch=gen_joint_friction_coefficient_torch, + category="joint_params", + ), + # Target Setters + MethodBenchmark( + name="set_joint_position_target", + method_name="set_joint_position_target", + input_generator_warp=gen_set_joint_position_target_warp, + input_generator_torch=gen_set_joint_position_target_torch, + category="targets", + ), + MethodBenchmark( + name="set_joint_velocity_target", + method_name="set_joint_velocity_target", + input_generator_warp=gen_set_joint_velocity_target_warp, + input_generator_torch=gen_set_joint_velocity_target_torch, + category="targets", + ), + MethodBenchmark( + name="set_joint_effort_target", + method_name="set_joint_effort_target", + input_generator_warp=gen_set_joint_effort_target_warp, + input_generator_torch=gen_set_joint_effort_target_torch, + category="targets", + ), +] + + +def benchmark_method( + articulation: Articulation, + method_benchmark: MethodBenchmark, + mode: InputMode, + config: BenchmarkConfig, +) -> BenchmarkResult: + """Benchmark a single method of Articulation. + + Args: + articulation: The Articulation instance. + method_benchmark: The method benchmark definition. + mode: Input mode (WARP or TORCH). + config: Benchmark configuration. + + Returns: + BenchmarkResult with timing statistics. + """ + method_name = method_benchmark.method_name + + # Check if method exists + if not hasattr(articulation, method_name): + return BenchmarkResult( + name=method_benchmark.name, + mode=mode, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason="Method not found", + ) + + method = getattr(articulation, method_name) + input_generator = ( + method_benchmark.input_generator_warp if mode == InputMode.WARP else method_benchmark.input_generator_torch + ) + + # Try to call the method once to check for errors + try: + inputs = input_generator(config) + method(**inputs) + except NotImplementedError as e: + return BenchmarkResult( + name=method_benchmark.name, + mode=mode, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason=f"NotImplementedError: {e}", + ) + except Exception as e: + return BenchmarkResult( + name=method_benchmark.name, + mode=mode, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason=f"Error: {type(e).__name__}: {e}", + ) + + # Warmup phase + for _ in range(config.warmup_steps): + inputs = input_generator(config) + with contextlib.suppress(Exception): + method(**inputs) + if config.device.startswith("cuda"): + wp.synchronize() + + # Timing phase + times = [] + for _ in range(config.num_iterations): + inputs = input_generator(config) + + # Sync before timing + if config.device.startswith("cuda"): + wp.synchronize() + + start_time = time.perf_counter() + try: + method(**inputs) + except Exception: + continue + + # Sync after to ensure kernel completion + if config.device.startswith("cuda"): + wp.synchronize() + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1e6) # Convert to microseconds + + if not times: + return BenchmarkResult( + name=method_benchmark.name, + mode=mode, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason="No successful iterations", + ) + + return BenchmarkResult( + name=method_benchmark.name, + mode=mode, + mean_time_us=float(np.mean(times)), + std_time_us=float(np.std(times)), + num_iterations=len(times), + ) + + +def run_benchmarks(config: BenchmarkConfig) -> tuple[list[BenchmarkResult], dict]: + """Run all benchmarks for Articulation. + + Args: + config: Benchmark configuration. + + Returns: + Tuple of (List of BenchmarkResults, hardware_info dict). + """ + results = [] + + # Gather and print hardware info + hardware_info = get_hardware_info() + print_hardware_info(hardware_info) + + # Create articulation + articulation, mock_view, _ = create_test_articulation( + num_instances=config.num_instances, + num_joints=config.num_joints, + num_bodies=config.num_bodies, + device=config.device, + ) + + # Determine modes to run + modes = [] + if config.mode in ("both", "warp"): + modes.append(InputMode.WARP) + if config.mode in ("both", "torch"): + modes.append(InputMode.TORCH) + + print(f"\nBenchmarking {len(METHOD_BENCHMARKS)} methods...") + print(f"Config: {config.num_iterations} iterations, {config.warmup_steps} warmup steps") + print(f" {config.num_instances} instances, {config.num_bodies} bodies, {config.num_joints} joints") + print(f"Modes: {', '.join(m.value for m in modes)}") + print("-" * 100) + + for i, method_benchmark in enumerate(METHOD_BENCHMARKS): + for mode in modes: + mode_str = f"[{mode.value.upper():5}]" + print(f"[{i + 1}/{len(METHOD_BENCHMARKS)}] {mode_str} {method_benchmark.name}...", end=" ", flush=True) + + result = benchmark_method(articulation, method_benchmark, mode, config) + results.append(result) + + if result.skipped: + print(f"SKIPPED ({result.skip_reason})") + else: + print(f"{result.mean_time_us:.2f} ± {result.std_time_us:.2f} µs") + + return results, hardware_info + + +def print_results(results: list[BenchmarkResult]): + """Print benchmark results in a formatted table.""" + print("\n" + "=" * 100) + print("BENCHMARK RESULTS") + print("=" * 100) + + # Separate by mode + warp_results = [r for r in results if r.mode == InputMode.WARP and not r.skipped] + torch_results = [r for r in results if r.mode == InputMode.TORCH and not r.skipped] + skipped = [r for r in results if r.skipped] + + # Print comparison table + if warp_results and torch_results: + print("\n" + "-" * 100) + print("COMPARISON: Warp (Best Case) vs Torch (Worst Case)") + print("-" * 100) + print(f"{'Method Name':<40} {'Warp (µs)':<15} {'Torch (µs)':<15} {'Overhead':<12} {'Slowdown':<10}") + print("-" * 100) + + warp_by_name = {r.name: r for r in warp_results} + torch_by_name = {r.name: r for r in torch_results} + + for name in warp_by_name: + if name in torch_by_name: + warp_time = warp_by_name[name].mean_time_us + torch_time = torch_by_name[name].mean_time_us + overhead = torch_time - warp_time + slowdown = torch_time / warp_time if warp_time > 0 else float("inf") + print(f"{name:<40} {warp_time:>12.2f} {torch_time:>12.2f} {overhead:>+9.2f} {slowdown:>7.2f}x") + + # Print individual results + for mode_name, mode_results in [("WARP (Best Case)", warp_results), ("TORCH (Worst Case)", torch_results)]: + if mode_results: + print("\n" + "-" * 100) + print(f"{mode_name}") + print("-" * 100) + + # Sort by mean time (descending) + mode_results_sorted = sorted(mode_results, key=lambda x: x.mean_time_us, reverse=True) + + print(f"{'Method Name':<45} {'Mean (µs)':<15} {'Std (µs)':<15} {'Iterations':<12}") + print("-" * 87) + + for result in mode_results_sorted: + print( + f"{result.name:<45} {result.mean_time_us:>12.2f} {result.std_time_us:>12.2f} " + f" {result.num_iterations:>10}" + ) + + # Summary stats + mean_times = [r.mean_time_us for r in mode_results_sorted] + print("-" * 87) + print(f" Fastest: {min(mean_times):.2f} µs ({mode_results_sorted[-1].name})") + print(f" Slowest: {max(mean_times):.2f} µs ({mode_results_sorted[0].name})") + print(f" Average: {np.mean(mean_times):.2f} µs") + + # Print skipped + if skipped: + print(f"\nSkipped Methods ({len(skipped)}):") + for result in skipped: + print(f" - {result.name} [{result.mode.value}]: {result.skip_reason}") + + +def export_results_json(results: list[BenchmarkResult], config: BenchmarkConfig, hardware_info: dict, filename: str): + """Export benchmark results to a JSON file.""" + import json + from datetime import datetime + + completed = [r for r in results if not r.skipped] + skipped = [r for r in results if r.skipped] + + git_info = get_git_info() + + output = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "repository": git_info, + "config": { + "num_iterations": config.num_iterations, + "warmup_steps": config.warmup_steps, + "num_instances": config.num_instances, + "num_bodies": config.num_bodies, + "num_joints": config.num_joints, + "device": config.device, + "mode": config.mode, + }, + "hardware": hardware_info, + "total_benchmarks": len(results), + "completed_benchmarks": len(completed), + "skipped_benchmarks": len(skipped), + }, + "results": { + "warp": [], + "torch": [], + }, + "comparison": [], + "skipped": [], + } + + # Add individual results + for result in completed: + result_entry = { + "name": result.name, + "mean_time_us": result.mean_time_us, + "std_time_us": result.std_time_us, + "num_iterations": result.num_iterations, + } + if result.mode == InputMode.WARP: + output["results"]["warp"].append(result_entry) + else: + output["results"]["torch"].append(result_entry) + + # Add comparison data + warp_by_name = {r.name: r for r in completed if r.mode == InputMode.WARP} + torch_by_name = {r.name: r for r in completed if r.mode == InputMode.TORCH} + + for name in warp_by_name: + if name in torch_by_name: + warp_time = warp_by_name[name].mean_time_us + torch_time = torch_by_name[name].mean_time_us + output["comparison"].append({ + "name": name, + "warp_time_us": warp_time, + "torch_time_us": torch_time, + "overhead_us": torch_time - warp_time, + "slowdown_factor": torch_time / warp_time if warp_time > 0 else None, + }) + + # Add skipped + for result in skipped: + output["skipped"].append({ + "name": result.name, + "mode": result.mode.value, + "reason": result.skip_reason, + }) + + with open(filename, "w") as jsonfile: + json.dump(output, jsonfile, indent=2) + + print(f"\nResults exported to {filename}") + + +def get_default_output_filename() -> str: + """Generate default output filename with current date and time.""" + from datetime import datetime + + datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + return f"articulation_benchmark_{datetime_str}.json" + + +def main(): + """Main entry point for the benchmarking script.""" + parser = argparse.ArgumentParser( + description="Micro-benchmarking framework for Articulation class.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--num_iterations", + type=int, + default=10000, + help="Number of iterations to run each method.", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=10, + help="Number of warmup steps before timing.", + ) + parser.add_argument( + "--num_instances", + type=int, + default=16384, + help="Number of articulation instances.", + ) + parser.add_argument( + "--num_bodies", + type=int, + default=12, + help="Number of bodies per articulation.", + ) + parser.add_argument( + "--num_joints", + type=int, + default=11, + help="Number of joints per articulation.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to run benchmarks on.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["warp", "torch", "both"], + default="both", + help="Benchmark mode: 'warp' (best case), 'torch' (worst case), or 'both'.", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output JSON file for benchmark results.", + ) + parser.add_argument( + "--no_json", + action="store_true", + help="Disable JSON output.", + ) + + args = parser.parse_args() + + config = BenchmarkConfig( + num_iterations=args.num_iterations, + warmup_steps=args.warmup_steps, + num_instances=args.num_instances, + num_bodies=args.num_bodies, + num_joints=args.num_joints, + device=args.device, + mode=args.mode, + ) + + # Run benchmarks + results, hardware_info = run_benchmarks(config) + + # Print results + print_results(results) + + # Export to JSON + if not args.no_json: + output_filename = args.output if args.output else get_default_output_filename() + export_results_json(results, config, hardware_info, output_filename) + + +if __name__ == "__main__": + main() diff --git a/source/isaaclab_newton/test/assets/articulation/benchmark_articulation_data.py b/source/isaaclab_newton/test/assets/articulation/benchmark_articulation_data.py new file mode 100644 index 00000000000..f6a2de6ca46 --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/benchmark_articulation_data.py @@ -0,0 +1,911 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Micro-benchmarking framework for ArticulationData class. + +This module provides a benchmarking framework to measure the performance of all functions +in the ArticulationData class. Each function is run multiple times with randomized mock data, +and timing statistics (mean and standard deviation) are reported. + +Usage: + python benchmark_articulation_data.py [--num_iterations N] [--warmup_steps W] [--num_instances I] [--num_bodies B] [--num_joints J] + +Example: + python benchmark_articulation_data.py --num_iterations 10000 --warmup_steps 10 +""" + +from __future__ import annotations + +import argparse +import contextlib +import numpy as np +import time +import torch +import warnings +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +import warp as wp +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData + +# Import mock classes from shared module +from mock_interface import MockNewtonArticulationView, MockNewtonModel + +# Initialize Warp +wp.init() + +# Suppress deprecation warnings during benchmarking +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) + + +def get_git_info() -> dict: + """Get git repository information. + + Returns: + Dictionary containing git commit hash, branch, and other info. + """ + import os + import subprocess + + git_info = { + "commit_hash": "Unknown", + "commit_hash_short": "Unknown", + "branch": "Unknown", + "commit_date": "Unknown", + } + + # Get the directory of this file to find the repo root + script_dir = os.path.dirname(os.path.abspath(__file__)) + + try: + # Get full commit hash + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["commit_hash"] = result.stdout.strip() + git_info["commit_hash_short"] = result.stdout.strip()[:8] + + # Get branch name + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["branch"] = result.stdout.strip() + + # Get commit date + result = subprocess.run( + ["git", "log", "-1", "--format=%ci"], + cwd=script_dir, + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + git_info["commit_date"] = result.stdout.strip() + + except Exception: + pass + + return git_info + + +def get_hardware_info() -> dict: + """Gather hardware information for the benchmark. + + Returns: + Dictionary containing CPU, GPU, and memory information. + """ + import os + import platform + + hardware_info = { + "cpu": { + "name": platform.processor() or "Unknown", + "physical_cores": os.cpu_count(), + }, + "gpu": {}, + "memory": {}, + "system": { + "platform": platform.system(), + "platform_release": platform.release(), + "platform_version": platform.version(), + "architecture": platform.machine(), + "python_version": platform.python_version(), + }, + } + + # Try to get more detailed CPU info on Linux + with contextlib.suppress(Exception): + with open("/proc/cpuinfo") as f: + cpuinfo = f.read() + for line in cpuinfo.split("\n"): + if "model name" in line: + hardware_info["cpu"]["name"] = line.split(":")[1].strip() + break + + # Memory info + try: + with open("/proc/meminfo") as f: + meminfo = f.read() + for line in meminfo.split("\n"): + if "MemTotal" in line: + # Convert from KB to GB + mem_kb = int(line.split()[1]) + hardware_info["memory"]["total_gb"] = round(mem_kb / (1024 * 1024), 2) + break + except Exception: + # Fallback using psutil if available + try: + import psutil + + mem = psutil.virtual_memory() + hardware_info["memory"]["total_gb"] = round(mem.total / (1024**3), 2) + except ImportError: + hardware_info["memory"]["total_gb"] = "Unknown" + + # GPU info using PyTorch + if torch.cuda.is_available(): + hardware_info["gpu"]["available"] = True + hardware_info["gpu"]["count"] = torch.cuda.device_count() + hardware_info["gpu"]["devices"] = [] + + for i in range(torch.cuda.device_count()): + gpu_props = torch.cuda.get_device_properties(i) + hardware_info["gpu"]["devices"].append({ + "index": i, + "name": gpu_props.name, + "total_memory_gb": round(gpu_props.total_memory / (1024**3), 2), + "compute_capability": f"{gpu_props.major}.{gpu_props.minor}", + "multi_processor_count": gpu_props.multi_processor_count, + }) + + # Current device info + current_device = torch.cuda.current_device() + hardware_info["gpu"]["current_device"] = current_device + hardware_info["gpu"]["current_device_name"] = torch.cuda.get_device_name(current_device) + else: + hardware_info["gpu"]["available"] = False + + # PyTorch and CUDA versions + hardware_info["gpu"]["pytorch_version"] = torch.__version__ + if torch.cuda.is_available(): + try: + import torch.version as torch_version + + cuda_version = getattr(torch_version, "cuda", None) + hardware_info["gpu"]["cuda_version"] = cuda_version if cuda_version else "Unknown" + except Exception: + hardware_info["gpu"]["cuda_version"] = "Unknown" + else: + hardware_info["gpu"]["cuda_version"] = "N/A" + + # Warp info + try: + warp_version = getattr(wp.config, "version", None) + hardware_info["warp"] = {"version": warp_version if warp_version else "Unknown"} + except Exception: + hardware_info["warp"] = {"version": "Unknown"} + + return hardware_info + + +def print_hardware_info(hardware_info: dict): + """Print hardware information to console. + + Args: + hardware_info: Dictionary containing hardware information. + """ + print("\n" + "=" * 80) + print("HARDWARE INFORMATION") + print("=" * 80) + + # System + sys_info = hardware_info.get("system", {}) + print(f"\nSystem: {sys_info.get('platform', 'Unknown')} {sys_info.get('platform_release', '')}") + print(f"Python: {sys_info.get('python_version', 'Unknown')}") + + # CPU + cpu_info = hardware_info.get("cpu", {}) + print(f"\nCPU: {cpu_info.get('name', 'Unknown')}") + print(f" Cores: {cpu_info.get('physical_cores', 'Unknown')}") + + # Memory + mem_info = hardware_info.get("memory", {}) + print(f"\nMemory: {mem_info.get('total_gb', 'Unknown')} GB") + + # GPU + gpu_info = hardware_info.get("gpu", {}) + if gpu_info.get("available", False): + print(f"\nGPU: {gpu_info.get('current_device_name', 'Unknown')}") + for device in gpu_info.get("devices", []): + print(f" [{device['index']}] {device['name']}") + print(f" Memory: {device['total_memory_gb']} GB") + print(f" Compute Capability: {device['compute_capability']}") + print(f" SM Count: {device['multi_processor_count']}") + print(f"\nPyTorch: {gpu_info.get('pytorch_version', 'Unknown')}") + print(f"CUDA: {gpu_info.get('cuda_version', 'Unknown')}") + else: + print("\nGPU: Not available") + + warp_info = hardware_info.get("warp", {}) + print(f"Warp: {warp_info.get('version', 'Unknown')}") + + # Repository info (get separately since it's not part of hardware) + repo_info = get_git_info() + print("\nRepository:") + print(f" Commit: {repo_info.get('commit_hash_short', 'Unknown')}") + print(f" Branch: {repo_info.get('branch', 'Unknown')}") + print(f" Date: {repo_info.get('commit_date', 'Unknown')}") + print("=" * 80) + + +@dataclass +class BenchmarkConfig: + """Configuration for the benchmarking framework.""" + + num_iterations: int = 10000 + """Number of iterations to run each function.""" + + warmup_steps: int = 10 + """Number of warmup steps before timing.""" + + num_instances: int = 16384 + """Number of articulation instances.""" + + num_bodies: int = 12 + """Number of bodies per articulation.""" + + num_joints: int = 11 + """Number of joints per articulation.""" + + device: str = "cuda:0" + """Device to run benchmarks on.""" + + +@dataclass +class BenchmarkResult: + """Result of a single benchmark.""" + + name: str + """Name of the benchmarked function/property.""" + + mean_time_us: float + """Mean execution time in microseconds.""" + + std_time_us: float + """Standard deviation of execution time in microseconds.""" + + num_iterations: int + """Number of iterations run.""" + + skipped: bool = False + """Whether the benchmark was skipped.""" + + skip_reason: str = "" + """Reason for skipping the benchmark.""" + + dependencies: list[str] | None = None + """List of parent properties that were pre-computed before timing.""" + + +# List of deprecated properties (for backward compatibility) - skip these +DEPRECATED_PROPERTIES = { + "default_root_state", + "root_pose_w", + "root_pos_w", + "root_quat_w", + "root_vel_w", + "root_lin_vel_w", + "root_ang_vel_w", + "root_lin_vel_b", + "root_ang_vel_b", + "body_pose_w", + "body_pos_w", + "body_quat_w", + "body_vel_w", + "body_lin_vel_w", + "body_ang_vel_w", + "body_acc_w", + "body_lin_acc_w", + "body_ang_acc_w", + "com_pos_b", + "com_quat_b", + "joint_limits", + "joint_friction", + "fixed_tendon_limit", + "applied_torque", + "computed_torque", + "joint_dynamic_friction", + "joint_effort_target", + "joint_viscous_friction", + "joint_velocity_limits", + # Also skip the combined state properties marked as deprecated + "root_state_w", + "root_link_state_w", + "root_com_state_w", + "body_state_w", + "body_link_state_w", + "body_com_state_w", +} + +# List of properties that raise NotImplementedError - skip these +NOT_IMPLEMENTED_PROPERTIES = { + "fixed_tendon_stiffness", + "fixed_tendon_damping", + "fixed_tendon_limit_stiffness", + "fixed_tendon_rest_length", + "fixed_tendon_offset", + "fixed_tendon_pos_limits", + "spatial_tendon_stiffness", + "spatial_tendon_damping", + "spatial_tendon_limit_stiffness", + "spatial_tendon_offset", + "body_incoming_joint_wrench_b", +} + +# Private/internal properties and methods to skip +INTERNAL_PROPERTIES = { + "_create_simulation_bindings", + "_create_buffers", + "update", + "is_primed", + "device", + "body_names", + "joint_names", + "fixed_tendon_names", + "spatial_tendon_names", + "GRAVITY_VEC_W", + "GRAVITY_VEC_W_TORCH", + "FORWARD_VEC_B", + "FORWARD_VEC_B_TORCH", + "ALL_ENV_MASK", + "ALL_BODY_MASK", + "ALL_JOINT_MASK", + "ENV_MASK", + "BODY_MASK", + "JOINT_MASK", +} + +# Dependency mapping: derived properties and their parent dependencies. +# Before benchmarking a derived property, we first call the parent to populate +# its cache, so we measure only the overhead of the derived property extraction. +PROPERTY_DEPENDENCIES = { + # Root link velocity slices (depend on root_link_vel_w) + "root_link_lin_vel_w": ["root_link_vel_w"], + "root_link_ang_vel_w": ["root_link_vel_w"], + # Root link velocity in body frame slices (depend on root_link_vel_b) + "root_link_lin_vel_b": ["root_link_vel_b"], + "root_link_ang_vel_b": ["root_link_vel_b"], + # Root COM pose slices (depend on root_com_pose_w) + "root_com_pos_w": ["root_com_pose_w"], + "root_com_quat_w": ["root_com_pose_w"], + # Root COM velocity slices (depend on root_com_vel_b) + "root_com_lin_vel_b": ["root_com_vel_b"], + "root_com_ang_vel_b": ["root_com_vel_b"], + # Root COM velocity in world frame slices (no lazy dependency, direct binding) + "root_com_lin_vel_w": ["root_com_vel_w"], + "root_com_ang_vel_w": ["root_com_vel_w"], + # Root link pose slices (no lazy dependency, direct binding) + "root_link_pos_w": ["root_link_pose_w"], + "root_link_quat_w": ["root_link_pose_w"], + # Body link velocity slices (depend on body_link_vel_w) + "body_link_lin_vel_w": ["body_link_vel_w"], + "body_link_ang_vel_w": ["body_link_vel_w"], + # Body link pose slices (no lazy dependency, direct binding) + "body_link_pos_w": ["body_link_pose_w"], + "body_link_quat_w": ["body_link_pose_w"], + # Body COM pose slices (depend on body_com_pose_w) + "body_com_pos_w": ["body_com_pose_w"], + "body_com_quat_w": ["body_com_pose_w"], + # Body COM velocity slices (no lazy dependency, direct binding) + "body_com_lin_vel_w": ["body_com_vel_w"], + "body_com_ang_vel_w": ["body_com_vel_w"], + # Body COM acceleration slices (depend on body_com_acc_w) + "body_com_lin_acc_w": ["body_com_acc_w"], + "body_com_ang_acc_w": ["body_com_acc_w"], + # Body COM pose/quat in body frame (depend on body_com_pose_b) + "body_com_quat_b": ["body_com_pose_b"], +} + + +def get_benchmarkable_properties(articulation_data: ArticulationData) -> list[str]: + """Get list of properties that can be benchmarked. + + Args: + articulation_data: The ArticulationData instance to inspect. + + Returns: + List of property names that can be benchmarked. + """ + all_properties = [] + + # Get all properties from the class + for name in dir(articulation_data): + # Skip private/dunder methods + if name.startswith("_"): + continue + + # Skip deprecated properties + if name in DEPRECATED_PROPERTIES: + continue + + # Skip not implemented properties + if name in NOT_IMPLEMENTED_PROPERTIES: + continue + + # Skip internal properties + if name in INTERNAL_PROPERTIES: + continue + + # Check if it's a property (not a method that needs arguments) + try: + attr = getattr(type(articulation_data), name, None) + if isinstance(attr, property): + all_properties.append(name) + except Exception: + pass + + return sorted(all_properties) + + +def setup_mock_environment( + config: BenchmarkConfig, +) -> tuple[MockNewtonArticulationView, MockNewtonModel, MagicMock, MagicMock]: + """Set up the mock environment for benchmarking. + + Args: + config: Benchmark configuration. + + Returns: + Tuple of (mock_view, mock_model, mock_state, mock_control). + """ + # Create mock Newton model + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + + # Create mock view + mock_view = MockNewtonArticulationView( + num_instances=config.num_instances, + num_bodies=config.num_bodies, + num_joints=config.num_joints, + device=config.device, + ) + + return mock_view, mock_model, mock_state, mock_control + + +def benchmark_property( + articulation_data: ArticulationData, + mock_view: MockNewtonArticulationView, + property_name: str, + config: BenchmarkConfig, +) -> BenchmarkResult: + """Benchmark a single property of ArticulationData. + + Args: + articulation_data: The ArticulationData instance. + mock_view: The mock view for setting random data. + property_name: Name of the property to benchmark. + config: Benchmark configuration. + + Returns: + BenchmarkResult with timing statistics. + """ + # Check if property exists + if not hasattr(articulation_data, property_name): + return BenchmarkResult( + name=property_name, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason="Property not found", + ) + + # Try to access the property once to check if it raises NotImplementedError + try: + _ = getattr(articulation_data, property_name) + except NotImplementedError as e: + return BenchmarkResult( + name=property_name, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason=f"NotImplementedError: {e}", + ) + except Exception as e: + return BenchmarkResult( + name=property_name, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason=f"Error: {type(e).__name__}: {e}", + ) + + # Get dependencies for this property (if any) + dependencies = PROPERTY_DEPENDENCIES.get(property_name, []) + + # Warmup phase with random data + for _ in range(config.warmup_steps): + mock_view.set_random_mock_data() + articulation_data._sim_timestamp += 1.0 # Invalidate cached data + try: + # Warm up dependencies first + for dep in dependencies: + _ = getattr(articulation_data, dep) + # Then warm up the target property + _ = getattr(articulation_data, property_name) + except Exception: + pass + # Sync GPU + if config.device.startswith("cuda"): + wp.synchronize() + + # Timing phase + times = [] + for _ in range(config.num_iterations): + # Randomize mock data each iteration + mock_view.set_random_mock_data() + articulation_data._sim_timestamp += 1.0 # Invalidate cached data + + # Call dependencies first to populate their caches (not timed) + # This ensures we only measure the overhead of the derived property + with contextlib.suppress(Exception): + for dep in dependencies: + _ = getattr(articulation_data, dep) + + # Sync before timing + if config.device.startswith("cuda"): + wp.synchronize() + + # Time only the target property access + start_time = time.perf_counter() + try: + _ = getattr(articulation_data, property_name) + except Exception: + continue + + # Sync after to ensure kernel completion + if config.device.startswith("cuda"): + wp.synchronize() + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1e6) # Convert to microseconds + + if not times: + return BenchmarkResult( + name=property_name, + mean_time_us=0.0, + std_time_us=0.0, + num_iterations=0, + skipped=True, + skip_reason="No successful iterations", + ) + + return BenchmarkResult( + name=property_name, + mean_time_us=float(np.mean(times)), + std_time_us=float(np.std(times)), + num_iterations=len(times), + dependencies=dependencies if dependencies else None, + ) + + +def run_benchmarks(config: BenchmarkConfig) -> tuple[list[BenchmarkResult], dict]: + """Run all benchmarks for ArticulationData. + + Args: + config: Benchmark configuration. + + Returns: + Tuple of (List of BenchmarkResults, hardware_info dict). + """ + results = [] + + # Gather and print hardware info + hardware_info = get_hardware_info() + print_hardware_info(hardware_info) + + # Setup mock environment + mock_view, mock_model, mock_state, mock_control = setup_mock_environment(config) + + # Patch NewtonManager + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager") as MockManager: + MockManager.get_model.return_value = mock_model + MockManager.get_state_0.return_value = mock_state + MockManager.get_control.return_value = mock_control + MockManager.get_dt.return_value = 0.01 + + # Initialize mock data + mock_view.set_random_mock_data() + + # Create ArticulationData instance + articulation_data = ArticulationData(mock_view, config.device) + + # Get list of properties to benchmark + properties = get_benchmarkable_properties(articulation_data) + + print(f"\nBenchmarking {len(properties)} properties...") + print(f"Config: {config.num_iterations} iterations, {config.warmup_steps} warmup steps") + print(f" {config.num_instances} instances, {config.num_bodies} bodies, {config.num_joints} joints") + print("-" * 80) + + for i, prop_name in enumerate(properties): + print(f"[{i + 1}/{len(properties)}] Benchmarking {prop_name}...", end=" ", flush=True) + + result = benchmark_property(articulation_data, mock_view, prop_name, config) + results.append(result) + + if result.skipped: + print(f"SKIPPED ({result.skip_reason})") + else: + print(f"{result.mean_time_us:.2f} ± {result.std_time_us:.2f} µs") + + return results, hardware_info + + +def print_results(results: list[BenchmarkResult]): + """Print benchmark results in a formatted table. + + Args: + results: List of BenchmarkResults to print. + """ + print("\n" + "=" * 80) + print("BENCHMARK RESULTS") + print("=" * 80) + + # Separate skipped and completed results + completed = [r for r in results if not r.skipped] + skipped = [r for r in results if r.skipped] + + # Sort completed by mean time (descending) + completed.sort(key=lambda x: x.mean_time_us, reverse=True) + + # Print header + print(f"\n{'Property Name':<45} {'Mean (µs)':<15} {'Std (µs)':<15} {'Iterations':<12}") + print("-" * 87) + + # Print completed results + for result in completed: + # Add marker for properties with pre-computed dependencies + name_display = result.name + if result.dependencies: + name_display = f"{result.name} *" + print( + f"{name_display:<45} {result.mean_time_us:>12.2f} {result.std_time_us:>12.2f} " + f" {result.num_iterations:>10}" + ) + + # Print summary statistics + if completed: + print("-" * 87) + mean_times = [r.mean_time_us for r in completed] + print("\nSummary Statistics:") + print(f" Total properties benchmarked: {len(completed)}") + print(f" Fastest: {min(mean_times):.2f} µs ({completed[-1].name})") + print(f" Slowest: {max(mean_times):.2f} µs ({completed[0].name})") + print(f" Average: {np.mean(mean_times):.2f} µs") + print(f" Median: {np.median(mean_times):.2f} µs") + + # Print note about derived properties + derived_count = sum(1 for r in completed if r.dependencies) + if derived_count > 0: + print(f"\n * = Derived property ({derived_count} total). Dependencies were pre-computed") + print(" before timing to measure isolated overhead.") + + # Print skipped results + if skipped: + print(f"\nSkipped Properties ({len(skipped)}):") + for result in skipped: + print(f" - {result.name}: {result.skip_reason}") + + +def export_results_csv(results: list[BenchmarkResult], filename: str): + """Export benchmark results to a CSV file. + + Args: + results: List of BenchmarkResults to export. + filename: Output CSV filename. + """ + import csv + + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["Property", "Mean (µs)", "Std (µs)", "Iterations", "Dependencies", "Skipped", "Skip Reason"]) + + for result in results: + deps_str = ", ".join(result.dependencies) if result.dependencies else "" + writer.writerow([ + result.name, + f"{result.mean_time_us:.4f}" if not result.skipped else "", + f"{result.std_time_us:.4f}" if not result.skipped else "", + result.num_iterations, + deps_str, + result.skipped, + result.skip_reason, + ]) + + print(f"\nResults exported to {filename}") + + +def export_results_json(results: list[BenchmarkResult], config: BenchmarkConfig, hardware_info: dict, filename: str): + """Export benchmark results to a JSON file. + + Args: + results: List of BenchmarkResults to export. + config: Benchmark configuration used. + hardware_info: Hardware information dictionary. + filename: Output JSON filename. + """ + import json + from datetime import datetime + + # Separate completed and skipped results + completed = [r for r in results if not r.skipped] + skipped = [r for r in results if r.skipped] + + # Get git repository info + git_info = get_git_info() + + # Build the JSON structure + output = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "repository": git_info, + "config": { + "num_iterations": config.num_iterations, + "warmup_steps": config.warmup_steps, + "num_instances": config.num_instances, + "num_bodies": config.num_bodies, + "num_joints": config.num_joints, + "device": config.device, + }, + "hardware": hardware_info, + "total_properties": len(results), + "benchmarked_properties": len(completed), + "skipped_properties": len(skipped), + }, + "results": [], + "skipped": [], + } + + # Add individual results + for result in completed: + result_entry = { + "name": result.name, + "mean_time_us": result.mean_time_us, + "std_time_us": result.std_time_us, + "num_iterations": result.num_iterations, + } + if result.dependencies: + result_entry["dependencies"] = result.dependencies + result_entry["note"] = "Timing excludes dependency computation (dependencies were pre-computed)" + output["results"].append(result_entry) + + # Add skipped properties + for result in skipped: + output["skipped"].append({ + "name": result.name, + "reason": result.skip_reason, + }) + + # Write JSON file + with open(filename, "w") as jsonfile: + json.dump(output, jsonfile, indent=2) + + print(f"Results exported to {filename}") + + +def get_default_output_filename() -> str: + """Generate default output filename with current date and time.""" + from datetime import datetime + + datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + return f"articulation_data_{datetime_str}.json" + + +def main(): + """Main entry point for the benchmarking script.""" + parser = argparse.ArgumentParser( + description="Micro-benchmarking framework for ArticulationData class.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--num_iterations", + type=int, + default=10000, + help="Number of iterations to run each function.", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=10, + help="Number of warmup steps before timing.", + ) + parser.add_argument( + "--num_instances", + type=int, + default=16384, + help="Number of articulation instances.", + ) + parser.add_argument( + "--num_bodies", + type=int, + default=12, + help="Number of bodies per articulation.", + ) + parser.add_argument( + "--num_joints", + type=int, + default=11, + help="Number of joints per articulation.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to run benchmarks on.", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default=None, + help="Output JSON file for benchmark results. Default: articulation_data_DATE.json", + ) + parser.add_argument( + "--export_csv", + type=str, + default=None, + help="Additionally export results to CSV file.", + ) + parser.add_argument( + "--no_json", + action="store_true", + help="Disable JSON output.", + ) + + args = parser.parse_args() + + config = BenchmarkConfig( + num_iterations=args.num_iterations, + warmup_steps=args.warmup_steps, + num_instances=args.num_instances, + num_bodies=args.num_bodies, + num_joints=args.num_joints, + device=args.device, + ) + + # Run benchmarks + results, hardware_info = run_benchmarks(config) + + # Print results + print_results(results) + + # Export to JSON (default) + if not args.no_json: + output_filename = args.output if args.output else get_default_output_filename() + export_results_json(results, config, hardware_info, output_filename) + + # Export to CSV if requested + if args.export_csv: + export_results_csv(results, args.export_csv) + + +if __name__ == "__main__": + main() diff --git a/source/isaaclab_newton/test/assets/articulation/mock_interface.py b/source/isaaclab_newton/test/assets/articulation/mock_interface.py new file mode 100644 index 00000000000..78be20ccd30 --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/mock_interface.py @@ -0,0 +1,551 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Mock interfaces for testing and benchmarking ArticulationData class.""" + +from __future__ import annotations + +import torch +from unittest.mock import MagicMock, patch + +import warp as wp + +## +# Mock classes for Newton +## + + +class MockNewtonModel: + """Mock Newton model that provides gravity.""" + + def __init__(self, gravity: tuple[float, float, float] = (0.0, 0.0, -9.81)): + self._gravity = wp.array([gravity], dtype=wp.vec3f, device="cuda:0") + + @property + def gravity(self): + return self._gravity + + +class MockNewtonArticulationView: + """Mock NewtonArticulationView that provides simulation bindings. + + This class mimics the interface that ArticulationData expects from Newton. + """ + + def __init__( + self, + num_instances: int, + num_bodies: int, + num_joints: int, + device: str = "cuda:0", + is_fixed_base: bool = False, + joint_names: list[str] | None = None, + body_names: list[str] | None = None, + ): + """Initialize the mock NewtonArticulationView. + + Args: + num_instances: Number of instances. + num_bodies: Number of bodies. + num_joints: Number of joints. + device: Device to use. + is_fixed_base: Whether the articulation is fixed-base. + joint_names: Names of joints. Defaults to ["joint_0", ...]. + body_names: Names of bodies. Defaults to ["body_0", ...]. + """ + # Set the parameters + self._count = num_instances + self._link_count = num_bodies + self._joint_dof_count = num_joints + self._device = device + self._is_fixed_base = is_fixed_base + + # Set joint and body names + if joint_names is None: + self._joint_dof_names = [f"joint_{i}" for i in range(num_joints)] + else: + self._joint_dof_names = joint_names + + if body_names is None: + self._body_names = [f"body_{i}" for i in range(num_bodies)] + else: + self._body_names = body_names + + # Storage for mock data + # Note: These are set via set_mock_data() before any property access in tests + self._root_transforms = wp.zeros((num_instances,), dtype=wp.transformf, device=device) + self._root_velocities = wp.zeros((num_instances,), dtype=wp.spatial_vectorf, device=device) + self._link_transforms = wp.zeros((num_instances, num_bodies), dtype=wp.transformf, device=device) + self._link_velocities = wp.zeros((num_instances, num_bodies), dtype=wp.spatial_vectorf, device=device) + self._dof_positions = wp.zeros((num_instances, num_joints), dtype=wp.float32, device=device) + self._dof_velocities = wp.zeros((num_instances, num_joints), dtype=wp.float32, device=device) + + # Initialize default attributes + self._attributes: dict = {} + self._attributes["body_com"] = wp.zeros((self._count, self._link_count), dtype=wp.vec3f, device=self._device) + self._attributes["body_mass"] = wp.zeros((self._count, self._link_count), dtype=wp.float32, device=self._device) + self._attributes["body_inertia"] = wp.zeros( + (self._count, self._link_count), dtype=wp.mat33f, device=self._device + ) + self._attributes["joint_limit_lower"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_limit_upper"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_ke"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_kd"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_armature"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_friction"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_velocity_limit"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_effort_limit"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["body_f"] = wp.zeros( + (self._count, self._link_count), dtype=wp.spatial_vectorf, device=self._device + ) + self._attributes["joint_f"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_pos"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_vel"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_limit_ke"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_limit_kd"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + + @property + def count(self) -> int: + return self._count + + @property + def link_count(self) -> int: + return self._link_count + + @property + def joint_dof_count(self) -> int: + return self._joint_dof_count + + @property + def is_fixed_base(self) -> bool: + return self._is_fixed_base + + @property + def joint_dof_names(self) -> list[str]: + return self._joint_dof_names + + @property + def body_names(self) -> list[str]: + return self._body_names + + def get_root_transforms(self, state) -> wp.array: + return self._root_transforms + + def get_root_velocities(self, state) -> wp.array: + return self._root_velocities + + def get_link_transforms(self, state) -> wp.array: + return self._link_transforms + + def get_link_velocities(self, state) -> wp.array: + return self._link_velocities + + def get_dof_positions(self, state) -> wp.array: + return self._dof_positions + + def get_dof_velocities(self, state) -> wp.array: + return self._dof_velocities + + def get_attribute(self, name: str, model_or_state) -> wp.array: + return self._attributes[name] + + def set_root_transforms(self, state, transforms: wp.array): + print(f"Setting root transforms: {transforms}") + print(f"Root transforms: {self._root_transforms}") + self._root_transforms.assign(transforms) + + def set_root_velocities(self, state, velocities: wp.array): + self._root_velocities.assign(velocities) + + def set_mock_data( + self, + root_transforms: wp.array | None = None, + root_velocities: wp.array | None = None, + link_transforms: wp.array | None = None, + link_velocities: wp.array | None = None, + body_com_pos: wp.array | None = None, + dof_positions: wp.array | None = None, + dof_velocities: wp.array | None = None, + body_mass: wp.array | None = None, + body_inertia: wp.array | None = None, + joint_limit_lower: wp.array | None = None, + joint_limit_upper: wp.array | None = None, + ): + """Set mock simulation data.""" + if root_transforms is None: + self._root_transforms.assign(wp.zeros((self._count,), dtype=wp.transformf, device=self._device)) + else: + self._root_transforms.assign(root_transforms) + if root_velocities is None: + self._root_velocities.assign(wp.zeros((self._count,), dtype=wp.spatial_vectorf, device=self._device)) + else: + self._root_velocities.assign(root_velocities) + if link_transforms is None: + self._link_transforms.assign( + wp.zeros((self._count, self._link_count), dtype=wp.transformf, device=self._device) + ) + else: + self._link_transforms.assign(link_transforms) + if link_velocities is None: + self._link_velocities.assign( + wp.zeros((self._count, self._link_count), dtype=wp.spatial_vectorf, device=self._device) + ) + else: + self._link_velocities.assign(link_velocities) + + # Set attributes that ArticulationData expects + if body_com_pos is None: + self._attributes["body_com"].assign( + wp.zeros((self._count, self._link_count), dtype=wp.vec3f, device=self._device) + ) + else: + self._attributes["body_com"].assign(body_com_pos) + + if dof_positions is None: + self._dof_positions.assign( + wp.zeros((self._count, self._joint_dof_count), dtype=wp.float32, device=self._device) + ) + else: + self._dof_positions.assign(dof_positions) + + if dof_velocities is None: + self._dof_velocities.assign( + wp.zeros((self._count, self._joint_dof_count), dtype=wp.float32, device=self._device) + ) + else: + self._dof_velocities.assign(dof_velocities) + + if body_mass is None: + self._attributes["body_mass"].assign( + wp.zeros((self._count, self._link_count), dtype=wp.float32, device=self._device) + ) + else: + self._attributes["body_mass"].assign(body_mass) + + if body_inertia is None: + # Initialize as identity inertia tensors + self._attributes["body_inertia"].assign( + wp.zeros((self._count, self._link_count), dtype=wp.mat33f, device=self._device) + ) + else: + self._attributes["body_inertia"].assign(body_inertia) + + if joint_limit_lower is not None: + self._attributes["joint_limit_lower"].assign(joint_limit_lower) + + if joint_limit_upper is not None: + self._attributes["joint_limit_upper"].assign(joint_limit_upper) + + def set_random_mock_data(self): + """Set randomized mock simulation data for benchmarking.""" + # Generate random root transforms (position + normalized quaternion) + root_pose = torch.zeros((self._count, 7), device=self._device) + root_pose[:, :3] = torch.rand((self._count, 3), device=self._device) * 10.0 - 5.0 # Random positions + root_pose[:, 3:] = torch.randn((self._count, 4), device=self._device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random velocities + root_vel = torch.rand((self._count, 6), device=self._device) * 2.0 - 1.0 + + # Generate random link transforms + link_pose = torch.zeros((self._count, self._link_count, 7), device=self._device) + link_pose[:, :, :3] = torch.rand((self._count, self._link_count, 3), device=self._device) * 10.0 - 5.0 + link_pose[:, :, 3:] = torch.randn((self._count, self._link_count, 4), device=self._device) + link_pose[:, :, 3:] = torch.nn.functional.normalize(link_pose[:, :, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random link velocities + link_vel = torch.rand((self._count, self._link_count, 6), device=self._device) * 2.0 - 1.0 + + # Generate random body COM positions + body_com_pos = torch.rand((self._count, self._link_count, 3), device=self._device) * 0.2 - 0.1 + + # Generate random joint positions and velocities + dof_pos = torch.rand((self._count, self._joint_dof_count), device=self._device) * 6.28 - 3.14 + dof_vel = torch.rand((self._count, self._joint_dof_count), device=self._device) * 2.0 - 1.0 + + # Generate random body masses (positive values) + body_mass = torch.rand((self._count, self._link_count), device=self._device) * 10.0 + 0.1 + + # Set the mock data + self.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(root_vel, dtype=wp.spatial_vectorf), + link_transforms=wp.from_torch(link_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(link_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + dof_positions=wp.from_torch(dof_pos, dtype=wp.float32), + dof_velocities=wp.from_torch(dof_vel, dtype=wp.float32), + body_mass=wp.from_torch(body_mass, dtype=wp.float32), + ) + + +class MockSharedMetaDataType: + """Mock shared meta data types.""" + + def __init__(self, fixed_base: bool, dof_count: int, link_count: int, dof_names: list[str], link_names: list[str]): + self._fixed_base: bool = fixed_base + self._dof_count: int = dof_count + self._link_count: int = link_count + self._dof_names: list[str] = dof_names + self._link_names: list[str] = link_names + + @property + def fixed_base(self) -> bool: + return self._fixed_base + + @property + def dof_count(self) -> int: + return self._dof_count + + @property + def link_count(self) -> int: + return self._link_count + + @property + def dof_names(self) -> list[str]: + return self._dof_names + + @property + def link_names(self) -> list[str]: + return self._link_names + + +class MockArticulationTensorAPI: + """Mock ArticulationView that provides tensor API like interface. + + This is for testing against the PhysX implementation which uses torch.Tensor. + """ + + def __init__( + self, + num_instances: int, + num_bodies: int, + num_joints: int, + device: str, + fixed_base: bool = False, + dof_names: list[str] = [], + link_names: list[str] = [], + ): + """Initialize the mock ArticulationTensorAPI. + + Args: + num_instances: Number of instances. + num_bodies: Number of bodies. + num_joints: Number of joints. + device: Device to use. + fixed_base: Whether the articulation is a fixed-base or floating-base system. (default: False) + dof_names: Names of the joints. (default: []) + link_names: Names of the bodies. (default: []) + """ + # Set the parameters + self._count = num_instances + self._link_count = num_bodies + self._joint_dof_count = num_joints + self._device = device + + # Mock shared meta data type + if not dof_names: + dof_names = [f"dof_{i}" for i in range(num_joints)] + else: + assert len(dof_names) == num_joints, "The number of dof names must be equal to the number of joints." + if not link_names: + link_names = [f"link_{i}" for i in range(num_bodies)] + else: + assert len(link_names) == num_bodies, "The number of link names must be equal to the number of bodies." + self._shared_metatype = MockSharedMetaDataType(fixed_base, num_joints, num_bodies, dof_names, link_names) + + # Storage for mock data + # Note: These are set via set_mock_data() before any property access in tests + self._root_transforms: torch.Tensor + self._root_velocities: torch.Tensor + self._link_transforms: torch.Tensor + self._link_velocities: torch.Tensor + self._link_acceleration: torch.Tensor + self._body_com: torch.Tensor + self._dof_positions: torch.Tensor + self._dof_velocities: torch.Tensor + self._body_mass: torch.Tensor + self._body_inertia: torch.Tensor + + # Initialize default attributes + self._attributes: dict = {} + + @property + def count(self) -> int: + return self._count + + @property + def shared_metatype(self) -> MockSharedMetaDataType: + return self._shared_metatype + + def get_dof_positions(self) -> torch.Tensor: + return self._dof_positions + + def get_dof_velocities(self) -> torch.Tensor: + return self._dof_velocities + + def get_root_transforms(self) -> torch.Tensor: + return self._root_transforms + + def get_root_velocities(self) -> torch.Tensor: + return self._root_velocities + + def get_link_transforms(self) -> torch.Tensor: + return self._link_transforms + + def get_link_velocities(self) -> torch.Tensor: + return self._link_velocities + + def get_link_acceleration(self) -> torch.Tensor: + return self._link_acceleration + + def get_coms(self) -> torch.Tensor: + return self._body_com + + def get_masses(self) -> torch.Tensor: + return self._body_mass + + def get_inertias(self) -> torch.Tensor: + return self._body_inertia + + def set_mock_data( + self, + root_transforms: torch.Tensor, + root_velocities: torch.Tensor, + link_transforms: torch.Tensor, + link_velocities: torch.Tensor, + body_com: torch.Tensor, + link_acceleration: torch.Tensor | None = None, + dof_positions: torch.Tensor | None = None, + dof_velocities: torch.Tensor | None = None, + body_mass: torch.Tensor | None = None, + body_inertia: torch.Tensor | None = None, + ): + """Set mock simulation data.""" + self._root_transforms = root_transforms + self._root_velocities = root_velocities + self._link_transforms = link_transforms + self._link_velocities = link_velocities + if link_acceleration is None: + self._link_acceleration = torch.zeros_like(link_velocities) + else: + self._link_acceleration = link_acceleration + self._body_com = body_com + + # Set attributes that ArticulationData expects + self._attributes["body_com"] = body_com + + if dof_positions is None: + self._dof_positions = torch.zeros( + (self._count, self._joint_dof_count), dtype=torch.float32, device=self._device + ) + else: + self._dof_positions = dof_positions + + if dof_velocities is None: + self._dof_velocities = torch.zeros( + (self._count, self._joint_dof_count), dtype=torch.float32, device=self._device + ) + else: + self._dof_velocities = dof_velocities + + if body_mass is None: + self._body_mass = torch.ones((self._count, self._link_count), dtype=torch.float32, device=self._device) + else: + self._body_mass = body_mass + self._attributes["body_mass"] = self._body_mass + + if body_inertia is None: + # Initialize as identity inertia tensors + self._body_inertia = torch.zeros( + (self._count, self._link_count, 9), dtype=torch.float32, device=self._device + ) + else: + self._body_inertia = body_inertia + self._attributes["body_inertia"] = self._body_inertia + + # Initialize other required attributes with defaults + self._attributes["joint_limit_lower"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_limit_upper"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_ke"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_kd"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_armature"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_friction"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_velocity_limit"] = wp.ones( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_effort_limit"] = wp.ones( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["body_f"] = wp.zeros( + (self._count, self._link_count), dtype=wp.spatial_vectorf, device=self._device + ) + self._attributes["joint_f"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_pos"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + self._attributes["joint_target_vel"] = wp.zeros( + (self._count, self._joint_dof_count), dtype=wp.float32, device=self._device + ) + + +def create_mock_newton_manager(gravity: tuple[float, float, float] = (0.0, 0.0, -9.81)): + """Create a mock NewtonManager for testing. + + Returns a context manager that patches the NewtonManager. + """ + mock_model = MockNewtonModel(gravity) + mock_state = MagicMock() + mock_control = MagicMock() + + return patch( + "isaaclab_newton.assets.articulation.articulation_data.NewtonManager", + **{ + "get_model.return_value": mock_model, + "get_state_0.return_value": mock_state, + "get_control.return_value": mock_control, + "get_dt.return_value": 0.01, + }, + ) diff --git a/source/isaaclab_newton/test/assets/articulation/test_articulation.py b/source/isaaclab_newton/test/assets/articulation/test_articulation.py new file mode 100644 index 00000000000..1129f238a46 --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/test_articulation.py @@ -0,0 +1,3829 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for Articulation class using mocked dependencies. + +This module provides unit tests for the Articulation class that bypass the heavy +initialization process (`_initialize_impl`) which requires a USD stage and real +simulation infrastructure. + +The key technique is to: +1. Create the Articulation object without calling __init__ using object.__new__ +2. Manually set up the required internal state with mock objects +3. Test individual methods in isolation + +This allows testing the mathematical operations and return values without +requiring full simulation integration. +""" + +from __future__ import annotations + +import torch +from unittest.mock import MagicMock, patch + +import pytest +import warp as wp +from isaaclab_newton.assets.articulation.articulation import Articulation +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData +from isaaclab_newton.kernels import vec13f + +from isaaclab.assets.articulation.articulation_cfg import ArticulationCfg + +# TODO: Move these functions to the test utils so they can't be changed in the future. +from isaaclab.utils.math import combine_frame_transforms, quat_apply, quat_inv + +# Import mock classes from shared module +from .mock_interface import MockNewtonArticulationView, MockNewtonModel + +# Initialize Warp +wp.init() + + +## +# Test Factory - Creates Articulation instances without full initialization +## + + +def create_test_articulation( + num_instances: int = 2, + num_joints: int = 6, + num_bodies: int = 7, + device: str = "cuda:0", + is_fixed_base: bool = False, + joint_names: list[str] | None = None, + body_names: list[str] | None = None, + soft_joint_pos_limit_factor: float = 1.0, +) -> tuple[Articulation, MockNewtonArticulationView, MagicMock]: + """Create a test Articulation instance with mocked dependencies. + + This factory bypasses _initialize_impl and manually sets up the internal state, + allowing unit testing of individual methods without requiring USD/simulation. + + Args: + num_instances: Number of environment instances. + num_joints: Number of joints in the articulation. + num_bodies: Number of bodies in the articulation. + device: Device to use ("cpu" or "cuda:0"). + is_fixed_base: Whether the articulation is fixed-base. + joint_names: Custom joint names. Defaults to ["joint_0", "joint_1", ...]. + body_names: Custom body names. Defaults to ["body_0", "body_1", ...]. + soft_joint_pos_limit_factor: Soft joint position limit factor. + + Returns: + A tuple of (articulation, mock_view, mock_newton_manager). + """ + # Generate default names if not provided + if joint_names is None: + joint_names = [f"joint_{i}" for i in range(num_joints)] + if body_names is None: + body_names = [f"body_{i}" for i in range(num_bodies)] + + # Create the Articulation without calling __init__ + articulation = object.__new__(Articulation) + + # Set up the configuration + articulation.cfg = ArticulationCfg( + prim_path="/World/Robot", + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + actuators={}, + ) + + # Set up the mock view with all parameters + mock_view = MockNewtonArticulationView( + num_instances=num_instances, + num_bodies=num_bodies, + num_joints=num_joints, + device=device, + is_fixed_base=is_fixed_base, + joint_names=joint_names, + body_names=body_names, + ) + mock_view.set_mock_data() + + # Set the view on the articulation (using object.__setattr__ to bypass type checking) + object.__setattr__(articulation, "_root_view", mock_view) + object.__setattr__(articulation, "_device", device) + + # Create mock NewtonManager + mock_newton_manager = MagicMock() + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + mock_newton_manager.get_model.return_value = mock_model + mock_newton_manager.get_state_0.return_value = mock_state + mock_newton_manager.get_control.return_value = mock_control + mock_newton_manager.get_dt.return_value = 0.01 + + # Create ArticulationData with the mock view + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager", mock_newton_manager): + data = ArticulationData(mock_view, device) + # Set the names on the data object (normally done by Articulation._initialize_impl) + data.joint_names = joint_names + data.body_names = body_names + object.__setattr__(articulation, "_data", data) + + return articulation, mock_view, mock_newton_manager + + +## +# Test Fixtures +## + + +@pytest.fixture +def mock_newton_manager(): + """Create mock NewtonManager with necessary methods.""" + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + + # Patch where NewtonManager is used (in the articulation module) + with patch("isaaclab_newton.assets.articulation.articulation.NewtonManager") as MockManager: + MockManager.get_model.return_value = mock_model + MockManager.get_state_0.return_value = mock_state + MockManager.get_control.return_value = mock_control + MockManager.get_dt.return_value = 0.01 + yield MockManager + + +@pytest.fixture +def test_articulation(): + """Create a test articulation with default parameters.""" + articulation, mock_view, mock_manager = create_test_articulation() + yield articulation, mock_view, mock_manager + + +## +# Test Cases -- Properties +## + + +class TestProperties: + """Tests for Articulation properties. + + Tests the following properties: + - data + - num_instances + - is_fixed_base + - num_joints + - num_fixed_tendons + - num_spatial_tendons + - num_bodies + - joint_names + - body_names + """ + + @pytest.mark.parametrize("num_instances", [1, 2, 4]) + def test_num_instances(self, num_instances: int): + """Test the num_instances property returns correct count.""" + articulation, _, _ = create_test_articulation(num_instances=num_instances) + assert articulation.num_instances == num_instances + + @pytest.mark.parametrize("num_joints", [1, 6]) + def test_num_joints(self, num_joints: int): + """Test the num_joints property returns correct count.""" + articulation, _, _ = create_test_articulation(num_joints=num_joints) + assert articulation.num_joints == num_joints + + @pytest.mark.parametrize("num_bodies", [1, 7]) + def test_num_bodies(self, num_bodies: int): + """Test the num_bodies property returns correct count.""" + articulation, _, _ = create_test_articulation(num_bodies=num_bodies) + assert articulation.num_bodies == num_bodies + + @pytest.mark.parametrize("is_fixed_base", [True, False]) + def test_is_fixed_base(self, is_fixed_base: bool): + """Test the is_fixed_base property.""" + articulation, _, _ = create_test_articulation(is_fixed_base=is_fixed_base) + assert articulation.is_fixed_base == is_fixed_base + + # TODO: Update when tendons are supported in Newton. + def test_num_fixed_tendons(self): + """Test that num_fixed_tendons returns 0 (not supported in Newton).""" + articulation, _, _ = create_test_articulation() + # Always returns 0 because fixed tendons are not supported in Newton. + assert articulation.num_fixed_tendons == 0 + + # TODO: Update when tendons are supported in Newton. + def test_num_spatial_tendons(self): + """Test that num_spatial_tendons returns 0 (not supported in Newton).""" + articulation, _, _ = create_test_articulation() + # Always returns 0 because spatial tendons are not supported in Newton. + assert articulation.num_spatial_tendons == 0 + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def testdata_property(self, device: str): + """Test that data property returns ArticulationData instance.""" + articulation, _, _ = create_test_articulation(device=device) + assert isinstance(articulation.data, ArticulationData) + + def test_joint_names(self): + """Test that joint_names returns the correct names.""" + custom_names = ["shoulder", "elbow", "wrist"] + articulation, _, _ = create_test_articulation( + num_joints=3, + joint_names=custom_names, + ) + assert articulation.joint_names == custom_names + + def test_body_names(self): + """Test that body_names returns the correct names.""" + custom_names = ["base", "link1", "link2", "end_effector"] + articulation, _, _ = create_test_articulation( + num_bodies=4, + body_names=custom_names, + ) + assert articulation.body_names == custom_names + + +## +# Test Cases -- Reset +## + + +class TestReset: + """Tests for reset method.""" + + def test_reset(self): + """Test that reset method works properly.""" + articulation, _, _ = create_test_articulation() + articulation.set_external_force_and_torque( + forces=torch.ones(articulation.num_instances, articulation.num_bodies, 3), + torques=torch.ones(articulation.num_instances, articulation.num_bodies, 3), + env_ids=slice(None), + body_ids=slice(None), + body_mask=None, + env_mask=None, + is_global=False, + ) + assert wp.to_torch(articulation.data._sim_bind_body_external_wrench).allclose( + torch.ones_like(wp.to_torch(articulation.data._sim_bind_body_external_wrench)) + ) + articulation.reset() + assert wp.to_torch(articulation.data._sim_bind_body_external_wrench).allclose( + torch.zeros_like(wp.to_torch(articulation.data._sim_bind_body_external_wrench)) + ) + + +## +# Test Cases -- Write Data to Sim. Skipped, this is mostly an integration test. +## + + +## +# Test Cases -- Update +## + + +class TestUpdate: + """Tests for update method.""" + + def test_update(self): + """Test that update method updates the simulation timestamp properly.""" + articulation, _, _ = create_test_articulation() + articulation.update(dt=0.01) + assert articulation.data._sim_timestamp == 0.01 + + +## +# Test Cases -- Finders +## + + +class TestFinders: + """Tests for finder methods.""" + + @pytest.mark.parametrize( + "body_names", + [["body_0", "body_1", "body_2"], ["body_3", "body_4", "body_5"], ["body_1", "body_3", "body_5"], "body_.*"], + ) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_bodies(self, body_names: list[str], device: str): + """Test that find_bodies method works properly.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_bodies(body_names) + if body_names == ["body_0", "body_1", "body_2"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[:3] = True + assert names == ["body_0", "body_1", "body_2"] + assert indices == [0, 1, 2] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == ["body_3", "body_4", "body_5"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[3:6] = True + assert names == ["body_3", "body_4", "body_5"] + assert indices == [3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == ["body_1", "body_3", "body_5"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert names == ["body_1", "body_3", "body_5"] + assert indices == [1, 3, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == "body_.*": + mask_ref = torch.ones((7,), dtype=torch.bool, device=device) + assert names == ["body_0", "body_1", "body_2", "body_3", "body_4", "body_5", "body_6"] + assert indices == [0, 1, 2, 3, 4, 5, 6] + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_body_with_preserve_order(self, device: str): + """Test that find_bodies method works properly with preserve_order.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_bodies(["body_5", "body_3", "body_1"], preserve_order=True) + assert names == ["body_5", "body_3", "body_1"] + assert indices == [5, 3, 1] + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize( + "joint_names", + [ + ["joint_0", "joint_1", "joint_2"], + ["joint_3", "joint_4", "joint_5"], + ["joint_1", "joint_3", "joint_5"], + "joint_.*", + ], + ) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_joints(self, joint_names: list[str], device: str): + """Test that find_joints method works properly.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_joints(joint_names) + if joint_names == ["joint_0", "joint_1", "joint_2"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[:3] = True + assert names == ["joint_0", "joint_1", "joint_2"] + assert indices == [0, 1, 2] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == ["joint_3", "joint_4", "joint_5"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[3:6] = True + assert names == ["joint_3", "joint_4", "joint_5"] + assert indices == [3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == ["joint_1", "joint_3", "joint_5"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert names == ["joint_1", "joint_3", "joint_5"] + assert indices == [1, 3, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == "joint_.*": + mask_ref = torch.ones((6,), dtype=torch.bool, device=device) + assert names == ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5"] + assert indices == [0, 1, 2, 3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_joints_with_preserve_order(self, device: str): + """Test that find_joints method works properly with preserve_order.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_joints(["joint_5", "joint_3", "joint_1"], preserve_order=True) + assert names == ["joint_5", "joint_3", "joint_1"] + assert indices == [5, 3, 1] + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert wp.to_torch(mask).allclose(mask_ref) + + # TODO: Update when tendons are supported in Newton. + def test_find_fixed_tendons(self): + """Test that find_fixed_tendons method works properly.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.find_fixed_tendons(["tendon_0", "tendon_1", "tendon_2"]) + + # TODO: Update when tendons are supported in Newton. + def test_find_spatial_tendons(self): + """Test that find_spatial_tendons method works properly.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.find_spatial_tendons(["tendon_0", "tendon_1", "tendon_2"]) + + +## +# Test Cases -- State Writers +## + + +class TestStateWriters: + """Tests for state writing methods.""" + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0], torch.tensor([0, 1, 2], dtype=torch.int32)]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_state_to_sim(data_warp, env_mask=mask_warp) + # Check results + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + articulation.write_root_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0], torch.tensor([0, 1, 2], dtype=torch.int32)]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_com_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Generate reference data + data_ref = torch.zeros((num_instances, 13), device=device) + data_ref[env_ids] = data + # Write to simulation + articulation.write_root_link_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_link_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + + +class TestVelocityWriters: + """Tests for velocity writing methods. + + Tests methods like: + - write_root_link_velocity_to_sim + - write_root_com_velocity_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 6), device=device) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 6), device=device) + env_ids = env_ids.to(device=device) + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_velocity_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_velocity_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_link_velocity_to_sim(wp.from_torch(data, dtype=wp.spatial_vectorf)) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 6), device=device) + # Generate warp data + data_warp = torch.ones((num_instances, 6), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.spatial_vectorf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=wp.spatial_vectorf) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_link_velocity_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 6), device=device) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 6), device=device) + env_ids = env_ids.to(device=device) + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_velocity_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_velocity_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_com_velocity_to_sim(wp.from_torch(data, dtype=wp.spatial_vectorf)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 6), device=device) + # Generate warp data + data_warp = torch.ones((num_instances, 6), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.spatial_vectorf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=wp.spatial_vectorf) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_com_velocity_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + + +class TestPoseWriters: + """Tests for pose writing methods. + + Tests methods like: + - write_root_link_pose_to_sim + - write_root_com_pose_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_pose_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_pose_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the pose transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_pose_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_pose_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Update all envs + articulation.write_root_link_pose_to_sim(wp.from_torch(data, dtype=wp.transformf)) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + else: + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 7), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.transformf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_link_pose_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_pose_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_pose_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_pose_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[:, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(root_link_pose, atol=1e-6, rtol=1e-6) + else: + if isinstance(env_ids, torch.Tensor): + env_ids = env_ids.to(device=device) + # Update selected envs + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_pose_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[env_ids, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose( + root_link_pose, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_pose_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_pose_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_pose_to_sim(wp.from_torch(data, dtype=wp.transformf)) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[:, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(root_link_pose, atol=1e-6, rtol=1e-6) + else: + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 7), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.transformf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_pose_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[env_ids, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose( + root_link_pose, atol=1e-6, rtol=1e-6 + ) + + +class TestJointState: + """Tests for joint state writing methods. + + Tests methods: + - write_joint_state_to_sim + - write_joint_position_to_sim + - write_joint_velocity_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_state_to_sim_torch(self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel).allclose(data2, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data2, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data2, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data2, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_state_to_sim_warp(self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_state_to_sim( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=None, + joint_mask=None, + ) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel).allclose(data2, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[:, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim(data1_warp, data2_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data2, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=env_mask, + joint_mask=None, + ) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data2, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids_, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim( + data1_warp, data2_warp, env_mask=env_mask, joint_mask=joint_mask + ) + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data2, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_position_to_sim( + wp.from_torch(data1, dtype=wp.float32), env_mask=None, joint_mask=None + ) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=env_mask, joint_mask=None) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_velocity_to_sim( + wp.from_torch(data1, dtype=wp.float32), env_mask=None, joint_mask=None + ) + assert wp.to_torch(articulation.data.joint_vel).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=env_mask, joint_mask=None) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + +## +# Test Cases -- Simulation Parameters Writers. +## + + +class TestWriteJointPropertiesToSim: + """Tests for writing joint properties to the simulation. + + Tests methods: + - write_joint_stiffness_to_sim + - write_joint_damping_to_sim + - write_joint_position_limit_to_sim + - write_joint_velocity_limit_to_sim + - write_joint_effort_limit_to_sim + - write_joint_armature_to_sim + - write_joint_friction_coefficient_to_sim + - write_joint_dynamic_friction_coefficient_to_sim + - write_joint_joint_friction_to_sim + - write_joint_limits_to_sim + """ + + def generic_test_property_writer_torch( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for i in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + if i % 2 == 0: + data1 = torch.rand((num_instances, num_joints), device=device) + else: + data1 = float(i) + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + if i % 2 == 0: + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + assert wp.to_torch(property_data).allclose( + data1 * torch.ones((num_instances, num_joints), device=device), atol=1e-6, rtol=1e-6 + ) + else: + # All envs and selected joints + if i % 2 == 0: + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + else: + data1 = float(i) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[:, joint_ids] = data1 + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), num_joints), device=device) + else: + data1 = float(i) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids, :] = data1 + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + else: + data1 = float(i) + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids_, joint_ids] = data1 + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for i in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + if i % 2 == 0: + data1 = torch.rand((num_instances, num_joints), device=device) + data1_warp = wp.from_torch(data1, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + writer_function(data1_warp, env_mask=None, joint_mask=None) + property_data = getattr(articulation.data, property_name) + if i % 2 == 0: + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + assert wp.to_torch(property_data).allclose( + data1 * torch.ones((num_instances, num_joints), device=device), atol=1e-6, rtol=1e-6 + ) + else: + # All envs and selected joints + if i % 2 == 0: + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[:, joint_ids] = data1 + writer_function(data1_warp, env_mask=None, joint_mask=joint_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids, :] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, joint_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + if i % 2 == 0: + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids_, joint_ids] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_torch_dual( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[:, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + property_data = getattr(articulation.data, property_name) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + assert wp.to_torch(property_data)[env_ids_, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp_dual( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + writer_function( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=None, + joint_mask=None, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data, atol=2e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[:, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function( + data1_warp, + data2_warp, + env_mask=None, + joint_mask=joint_mask, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[:, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function( + data1_warp, + data2_warp, + env_mask=env_mask, + joint_mask=None, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids_, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function(data1_warp, data2_warp, env_mask=env_mask, joint_mask=joint_mask) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids_, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_stiffness_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_stiffness_to_sim", "joint_stiffness" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_stiffness_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_stiffness_to_sim", "joint_stiffness" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_damping_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_damping_to_sim", "joint_damping" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_damping_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_damping_to_sim", "joint_damping" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_velocity_limit_to_sim", + "joint_vel_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_limit_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_velocity_limit_to_sim", + "joint_vel_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_effort_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_effort_limit_to_sim", + "joint_effort_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_effort_limit_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_effort_limit_to_sim", + "joint_effort_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_armature_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_armature_to_sim", "joint_armature" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_armature_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_armature_to_sim", "joint_armature" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_coefficient_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_friction_coefficient_to_sim", + "joint_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_coefficient_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_friction_coefficient_to_sim", + "joint_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_dynamic_friction_coefficient_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_dynamic_friction_coefficient_to_sim", + "joint_dynamic_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_dynamic_friction_coefficient_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_dynamic_friction_coefficient_to_sim", + "joint_dynamic_friction_coeff", + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_friction_to_sim", "joint_friction_coeff" + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_friction_to_sim", "joint_friction_coeff" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch_dual( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_position_limit_to_sim", + "joint_pos_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_limit_to_sim_warp_dual( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp_dual( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_position_limit_to_sim", + "joint_pos_limits", + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_limits_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch_dual( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_limits_to_sim", "joint_pos_limits" + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_limits_to_sim_warp_dual( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp_dual( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_limits_to_sim", "joint_pos_limits" + ) + + +## +# Test Cases - Setters. +## + + +class TestSettersBodiesMassCoMInertia: + """Tests for setter methods that set body mass, center of mass, and inertia. + + Tests methods: + - set_masses + - set_coms + - set_inertias + """ + + def generic_test_property_writer_torch( + self, + device: str, + env_ids, + body_ids, + num_instances: int, + num_bodies: int, + writer_function_name: str, + property_name: str, + dtype: type = wp.float32, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_bodies=num_bodies, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_bodies == 1: + if (body_ids is not None) and (not isinstance(body_ids, slice)): + body_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + if dtype == wp.float32: + ndims = tuple() + elif dtype == wp.vec3f: + ndims = (3,) + elif dtype == wp.mat33f: + ndims = ( + 3, + 3, + ) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (body_ids is None) or (isinstance(body_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_bodies, *ndims), device=device) + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected bodies + data1 = torch.rand((num_instances, len(body_ids), *ndims), device=device) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[:, body_ids] = data1 + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if (body_ids is None) or (isinstance(body_ids, slice)): + # Selected envs and all bodies + data1 = torch.rand((len(env_ids), num_bodies, *ndims), device=device) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids, :] = data1 + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and bodies + data1 = torch.rand((len(env_ids), len(body_ids), *ndims), device=device) + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids_, body_ids] = data1 + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp( + self, + device: str, + env_ids, + body_ids, + num_instances: int, + num_bodies: int, + writer_function_name: str, + property_name: str, + dtype: type = wp.float32, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_bodies=num_bodies, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_bodies == 1: + if (body_ids is not None) and (not isinstance(body_ids, slice)): + body_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + if dtype == wp.float32: + ndims = tuple() + elif dtype == wp.vec3f: + ndims = (3,) + elif dtype == wp.mat33f: + ndims = ( + 3, + 3, + ) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for _ in range(5): + if env_ids is None: + if body_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_bodies, *ndims), device=device) + data1_warp = wp.from_torch(data1, dtype=dtype) + writer_function(data1_warp, env_mask=None, body_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(body_ids), *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[:, body_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + body_mask = torch.zeros((num_bodies,), dtype=torch.bool, device=device) + body_mask[body_ids] = True + body_mask = wp.from_torch(body_mask, dtype=wp.bool) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[:, body_ids] = data1 + writer_function(data1_warp, env_mask=None, body_mask=body_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if body_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_bodies, *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids, :] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, body_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1 = torch.rand((len(env_ids), len(body_ids), *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[env_ids_, body_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids_, body_ids] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + body_mask = torch.zeros((num_bodies,), dtype=torch.bool, device=device) + body_mask[body_ids] = True + body_mask = wp.from_torch(body_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, body_mask=body_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_masses_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_masses", "body_mass", dtype=wp.float32 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_masses_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_masses", "body_mass", dtype=wp.float32 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_coms_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_coms", "body_com_pos_b", dtype=wp.vec3f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_coms_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_coms", "body_com_pos_b", dtype=wp.vec3f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_inertias_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_inertias", "body_inertia", dtype=wp.mat33f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_inertias_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_inertias", "body_inertia", dtype=wp.mat33f + ) + + +# TODO: Implement these tests once the Wrench Composers made it to main IsaacLab. +class TestSettersExternalWrench: + """Tests for setter methods that set external wrench. + + Tests methods: + - set_external_force_and_torque + """ + + @pytest.mark.skip(reason="Not implemented") + def test_external_force_and_torque_to_sim_torch( + self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int + ): + raise NotImplementedError() + + @pytest.mark.skip(reason="Not implemented") + def test_external_force_and_torque_to_sim_warp( + self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int + ): + raise NotImplementedError() + + +class TestFixedTendonsSetters: + """Tests for setter methods that set fixed tendon properties. + + Tests methods: + - set_fixed_tendon_stiffness + - set_fixed_tendon_damping + - set_fixed_tendon_limit_stiffness + - set_fixed_tendon_position_limit + - set_fixed_tendon_limit (deprecated) + - set_fixed_tendon_rest_length + - set_fixed_tendon_offset + - write_fixed_tendon_properties_to_sim + """ + + def test_set_fixed_tendon_stiffness_not_implemented(self): + """Test that set_fixed_tendon_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_stiffness(stiffness) + + def test_set_fixed_tendon_damping_not_implemented(self): + """Test that set_fixed_tendon_damping raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + damping = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_damping(damping) + + def test_set_fixed_tendon_limit_stiffness_not_implemented(self): + """Test that set_fixed_tendon_limit_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit_stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_limit_stiffness(limit_stiffness) + + def test_set_fixed_tendon_position_limit_not_implemented(self): + """Test that set_fixed_tendon_position_limit raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_position_limit(limit) + + def test_set_fixed_tendon_limit_not_implemented(self): + """Test that set_fixed_tendon_limit (deprecated) raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_limit(limit) + + def test_set_fixed_tendon_rest_length_not_implemented(self): + """Test that set_fixed_tendon_rest_length raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + rest_length = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_rest_length(rest_length) + + def test_set_fixed_tendon_offset_not_implemented(self): + """Test that set_fixed_tendon_offset raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + offset = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_offset(offset) + + def test_write_fixed_tendon_properties_to_sim_not_implemented(self): + """Test that write_fixed_tendon_properties_to_sim raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.write_fixed_tendon_properties_to_sim() + + +class TestSpatialTendonsSetters: + """Tests for setter methods that set spatial tendon properties. + + Tests methods: + - set_spatial_tendon_stiffness + - set_spatial_tendon_damping + - set_spatial_tendon_limit_stiffness + - set_spatial_tendon_offset + - write_spatial_tendon_properties_to_sim + """ + + def test_set_spatial_tendon_stiffness_not_implemented(self): + """Test that set_spatial_tendon_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_stiffness(stiffness) + + def test_set_spatial_tendon_damping_not_implemented(self): + """Test that set_spatial_tendon_damping raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + damping = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_damping(damping) + + def test_set_spatial_tendon_limit_stiffness_not_implemented(self): + """Test that set_spatial_tendon_limit_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit_stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_limit_stiffness(limit_stiffness) + + def test_set_spatial_tendon_offset_not_implemented(self): + """Test that set_spatial_tendon_offset raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + offset = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_offset(offset) + + def test_write_spatial_tendon_properties_to_sim_not_implemented(self): + """Test that write_spatial_tendon_properties_to_sim raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.write_spatial_tendon_properties_to_sim() + + +class TestCreateBuffers: + """Tests for _create_buffers method. + + Tests that the buffers are created correctly: + - _ALL_INDICES tensor contains correct indices for varying number of environments + - soft_joint_pos_limits are correctly computed based on soft_joint_pos_limit_factor + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("num_instances", [1, 2, 4, 10, 100]) + def test_create_buffers_all_indices(self, device: str, num_instances: int): + """Test that _ALL_INDICES contains correct indices for varying number of environments.""" + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up joint limits (required for _create_buffers) + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES + expected_indices = torch.arange(num_instances, dtype=torch.long, device=device) + assert articulation._ALL_INDICES.shape == (num_instances,) + assert articulation._ALL_INDICES.dtype == torch.long + assert articulation._ALL_INDICES.device.type == device.split(":")[0] + torch.testing.assert_close(articulation._ALL_INDICES, expected_indices) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_1(self, device: str): + """Test soft_joint_pos_limits with factor=1.0 (limits unchanged).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 1.0 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # With factor=1.0, soft limits should equal hard limits + # soft_joint_pos_limits is wp.vec2f (lower, upper) + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + # Shape is (num_instances, num_joints, 2) after conversion + expected_lower = torch.full((num_instances, num_joints), -2.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 2.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_half(self, device: str): + """Test soft_joint_pos_limits with factor=0.5 (limits halved around mean).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 0.5 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] for all joints + # mean = 0.0, range = 4.0 + # soft_lower = 0.0 - 0.5 * 4.0 * 0.5 = -1.0 + # soft_upper = 0.0 + 0.5 * 4.0 * 0.5 = 1.0 + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify soft limits are halved + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + expected_lower = torch.full((num_instances, num_joints), -1.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 1.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_asymmetric(self, device: str): + """Test soft_joint_pos_limits with asymmetric joint limits.""" + num_instances = 2 + num_joints = 3 + soft_joint_pos_limit_factor = 0.8 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up asymmetric joint limits + # Joint 0: [-3.14, 3.14] -> mean=0, range=6.28 -> soft: [-2.512, 2.512] + # Joint 1: [-1.0, 2.0] -> mean=0.5, range=3.0 -> soft: [0.5-1.2, 0.5+1.2] = [-0.7, 1.7] + # Joint 2: [0.0, 1.0] -> mean=0.5, range=1.0 -> soft: [0.5-0.4, 0.5+0.4] = [0.1, 0.9] + joint_limit_lower = torch.tensor([[-3.14, -1.0, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 2.0, 1.0]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Calculate expected soft limits + # soft_lower = mean - 0.5 * range * factor + # soft_upper = mean + 0.5 * range * factor + expected_lower = torch.tensor([[-2.512, -0.7, 0.1]] * num_instances, device=device) + expected_upper = torch.tensor([[2.512, 1.7, 0.9]] * num_instances, device=device) + + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_zero(self, device: str): + """Test soft_joint_pos_limits with factor=0.0 (limits collapse to mean).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 0.0 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] + # mean = 0.0, with factor=0.0, soft limits collapse to [0, 0] + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # With factor=0.0, soft limits should collapse to the mean + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + expected_lower = torch.full((num_instances, num_joints), 0.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 0.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_per_joint_different(self, device: str): + """Test soft_joint_pos_limits with different limits per joint.""" + num_instances = 3 + num_joints = 4 + soft_joint_pos_limit_factor = 0.9 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Each joint has different limits + joint_limit_lower = torch.tensor([[-1.0, -2.0, -0.5, -3.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[1.0, 2.0, 0.5, 3.0]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Calculate expected: soft_lower/upper = mean ± 0.5 * range * factor + # Joint 0: mean=0, range=2 -> [0 - 0.9, 0 + 0.9] = [-0.9, 0.9] + # Joint 1: mean=0, range=4 -> [0 - 1.8, 0 + 1.8] = [-1.8, 1.8] + # Joint 2: mean=0, range=1 -> [0 - 0.45, 0 + 0.45] = [-0.45, 0.45] + # Joint 3: mean=0, range=6 -> [0 - 2.7, 0 + 2.7] = [-2.7, 2.7] + expected_lower = torch.tensor([[-0.9, -1.8, -0.45, -2.7]] * num_instances, device=device) + expected_upper = torch.tensor([[0.9, 1.8, 0.45, 2.7]] * num_instances, device=device) + + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_single_environment(self, device: str): + """Test _create_buffers with a single environment.""" + num_instances = 1 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES has single element + assert articulation._ALL_INDICES.shape == (1,) + assert articulation._ALL_INDICES[0].item() == 0 + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_large_number_of_environments(self, device: str): + """Test _create_buffers with a large number of environments.""" + num_instances = 1024 + num_joints = 12 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES + expected_indices = torch.arange(num_instances, dtype=torch.long, device=device) + assert articulation._ALL_INDICES.shape == (num_instances,) + torch.testing.assert_close(articulation._ALL_INDICES, expected_indices) + + # Verify soft limits shape + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + assert soft_limits.shape == (num_instances, num_joints, 2) + + +class TestProcessCfg: + """Tests for _process_cfg method. + + Tests that the configuration processing correctly: + - Converts quaternion from (w, x, y, z) to (x, y, z, w) format for default root pose + - Sets default root velocity from lin_vel and ang_vel + - Sets default joint positions from joint_pos dict with pattern matching + - Sets default joint velocities from joint_vel dict with pattern matching + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_root_pose(self, device: str): + """Test that _process_cfg correctly converts quaternion format for root pose.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with specific position and rotation + # Rotation is in (w, x, y, z) format in the config + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (0.707, 0.0, 0.707, 0.0) # w, x, y, z + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root pose + # Expected: position (1, 2, 3) + quaternion converted to (x, y, z, w) = (0, 0.707, 0, 0.707) + expected_pose = torch.tensor( + [[1.0, 2.0, 3.0, 0.0, 0.707, 0.0, 0.707]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_pose) + assert result.allclose(expected_pose, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_root_velocity(self, device: str): + """Test that _process_cfg correctly sets default root velocity.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with specific velocities + articulation.cfg.init_state.lin_vel = (1.0, 2.0, 3.0) + articulation.cfg.init_state.ang_vel = (0.1, 0.2, 0.3) + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root velocity + # Expected: lin_vel + ang_vel = (1, 2, 3, 0.1, 0.2, 0.3) + expected_vel = torch.tensor( + [[1.0, 2.0, 3.0, 0.1, 0.2, 0.3]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_vel) + assert result.allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_all_joints(self, device: str): + """Test that _process_cfg correctly sets default joint positions for all joints.""" + num_instances = 2 + num_joints = 4 + joint_names = ["joint_0", "joint_1", "joint_2", "joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with joint positions using wildcard pattern + articulation.cfg.init_state.joint_pos = {".*": 0.5} + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.full((num_instances, num_joints), 0.5, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_specific_joints(self, device: str): + """Test that _process_cfg correctly sets default joint positions for specific joints.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with specific joint positions + articulation.cfg.init_state.joint_pos = { + "shoulder": 1.0, + "elbow": 2.0, + "wrist": 3.0, + "gripper": 4.0, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.tensor([[1.0, 2.0, 3.0, 4.0]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_regex_pattern(self, device: str): + """Test that _process_cfg correctly handles regex patterns for joint positions.""" + num_instances = 2 + num_joints = 6 + joint_names = ["arm_joint_1", "arm_joint_2", "arm_joint_3", "hand_joint_1", "hand_joint_2", "hand_joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with regex patterns + articulation.cfg.init_state.joint_pos = { + "arm_joint_.*": 1.5, + "hand_joint_.*": 0.5, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + # arm joints (indices 0-2) should be 1.5, hand joints (indices 3-5) should be 0.5 + expected_pos = torch.tensor([[1.5, 1.5, 1.5, 0.5, 0.5, 0.5]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_velocities(self, device: str): + """Test that _process_cfg correctly sets default joint velocities.""" + num_instances = 2 + num_joints = 4 + joint_names = ["joint_0", "joint_1", "joint_2", "joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with joint velocities + articulation.cfg.init_state.joint_pos = {".*": 0.0} + articulation.cfg.init_state.joint_vel = { + "joint_0": 0.1, + "joint_1": 0.2, + "joint_2": 0.3, + "joint_3": 0.4, + } + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint velocities + expected_vel = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_vel) + assert result.allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_identity_quaternion(self, device: str): + """Test that _process_cfg correctly handles identity quaternion.""" + num_instances = 2 + num_joints = 2 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with identity quaternion (w=1, x=0, y=0, z=0) + articulation.cfg.init_state.pos = (0.0, 0.0, 0.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) # Identity: w, x, y, z + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root pose + # Expected: position (0, 0, 0) + quaternion converted to (x, y, z, w) = (0, 0, 0, 1) + expected_pose = torch.tensor( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_pose) + assert result.allclose(expected_pose, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_zero_joints(self, device: str): + """Test that _process_cfg handles articulation with no joints.""" + num_instances = 2 + num_joints = 0 + num_bodies = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + num_bodies=num_bodies, + device=device, + ) + + # Set up init_state + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) + articulation.cfg.init_state.lin_vel = (0.5, 0.5, 0.5) + articulation.cfg.init_state.ang_vel = (0.1, 0.1, 0.1) + articulation.cfg.init_state.joint_pos = {} + articulation.cfg.init_state.joint_vel = {} + + # Call _process_cfg - should not raise any exception + articulation._process_cfg() + + # Verify root pose and velocity are still set correctly + expected_pose = torch.tensor( + [[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0]] * num_instances, + device=device, + ) + expected_vel = torch.tensor( + [[0.5, 0.5, 0.5, 0.1, 0.1, 0.1]] * num_instances, + device=device, + ) + assert wp.to_torch(articulation.data.default_root_pose).allclose(expected_pose, atol=1e-5, rtol=1e-5) + assert wp.to_torch(articulation.data.default_root_vel).allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_mixed_joint_patterns(self, device: str): + """Test that _process_cfg handles mixed specific and pattern-based joint settings.""" + num_instances = 2 + num_joints = 5 + joint_names = ["base_joint", "arm_1", "arm_2", "hand_1", "hand_2"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with mixed patterns + articulation.cfg.init_state.joint_pos = { + "base_joint": 0.0, + "arm_.*": 1.0, + "hand_.*": 2.0, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.tensor([[0.0, 1.0, 1.0, 2.0, 2.0]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose(self, device: str): + """Test that _process_cfg offsets the spawned position by the default root pose.""" + num_instances = 3 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose in config: position (1.0, 2.0, 3.0), identity quaternion + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) # w, x, y, z (identity) + + # Set up initial spawned positions for each instance + # Instance 0: (5.0, 6.0, 0.0) + # Instance 1: (10.0, 20.0, 0.0) + # Instance 2: (-3.0, -4.0, 0.0) + spawned_transforms = torch.tensor( + [ + [5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0], # pos (x,y,z), quat (x,y,z,w) + [10.0, 20.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [-3.0, -4.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # Verify that the root transforms are offset by default pose's x,y + # Expected: spawned_pose[:, :2] + default_pose[:2] + # Instance 0: (5.0 + 1.0, 6.0 + 2.0, 3.0) = (6.0, 8.0, 3.0) + # Instance 1: (10.0 + 1.0, 20.0 + 2.0, 3.0) = (11.0, 22.0, 3.0) + # Instance 2: (-3.0 + 1.0, -4.0 + 2.0, 3.0) = (-2.0, -2.0, 3.0) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [6.0, 8.0, 3.0, 0.0, 0.0, 0.0, 1.0], + [11.0, 22.0, 3.0, 0.0, 0.0, 0.0, 1.0], + [-2.0, -2.0, 3.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose_zero_offset(self, device: str): + """Test that _process_cfg with zero default position keeps spawned position unchanged in x,y.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose with zero position + articulation.cfg.init_state.pos = (0.0, 0.0, 0.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) + + # Set up initial spawned positions + spawned_transforms = torch.tensor( + [ + [5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 1.0], + [10.0, 20.0, 30.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # With zero default position, x,y should stay the same, z comes from default (0.0) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [10.0, 20.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose_with_rotation(self, device: str): + """Test that _process_cfg correctly sets rotation while offsetting position.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose with specific rotation (90 degrees around z-axis) + # Quaternion for 90 degrees around z: (w=0.707, x=0, y=0, z=0.707) + articulation.cfg.init_state.pos = (1.0, 2.0, 5.0) + articulation.cfg.init_state.rot = (0.707, 0.0, 0.0, 0.707) # w, x, y, z + + # Set up initial spawned positions + spawned_transforms = torch.tensor( + [ + [3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [6.0, 8.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # Verify position offset and rotation is set correctly + # Position: spawned[:2] + default[:2], z from default + # Rotation: from default (converted to x,y,z,w format) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [4.0, 6.0, 5.0, 0.0, 0.0, 0.707, 0.707], # x,y,z, qx,qy,qz,qw + [7.0, 10.0, 5.0, 0.0, 0.0, 0.707, 0.707], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-3, rtol=1e-3) + + +class TestValidateCfg: + """Tests for _validate_cfg method. + + Tests that the configuration validation correctly catches: + - Default joint positions outside of joint limits (lower and upper bounds) + - Various edge cases with joint limits + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_positions_within_limits(self, device: str): + """Test that _validate_cfg passes when all default positions are within limits.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint positions within limits + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should not raise any exception + articulation._validate_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_position_below_lower_limit(self, device: str): + """Test that _validate_cfg raises ValueError when a position is below the lower limit.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint position for joint 2 below the lower limit + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 2] = -1.5 # Below -1.0 lower limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_2" in str(exc_info.value) + assert "-1.500" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_position_above_upper_limit(self, device: str): + """Test that _validate_cfg raises ValueError when a position is above the upper limit.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint position for joint 4 above the upper limit + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 4] = 1.5 # Above 1.0 upper limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_4" in str(exc_info.value) + assert "1.500" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_multiple_positions_out_of_limits(self, device: str): + """Test that _validate_cfg reports all joints with positions outside limits.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set multiple joints out of limits + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 0] = -2.0 # Below lower limit + default_joint_pos[:, 3] = 2.0 # Above upper limit + default_joint_pos[:, 5] = -1.5 # Below lower limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError mentioning all violated joints + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + error_msg = str(exc_info.value) + assert "joint_0" in error_msg + assert "joint_3" in error_msg + assert "joint_5" in error_msg + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_asymmetric_limits(self, device: str): + """Test that _validate_cfg works with asymmetric joint limits.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set asymmetric joint limits for each joint + joint_limit_lower = torch.tensor([[-3.14, -2.0, -1.5, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 0.5, 1.5, 0.1]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set positions within asymmetric limits + default_joint_pos = torch.tensor([[0.0, -1.0, 0.0, 0.05]] * num_instances, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should not raise any exception + articulation._validate_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_asymmetric_limits_violated(self, device: str): + """Test that _validate_cfg catches violations with asymmetric limits.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set asymmetric joint limits: elbow has range [-2.0, 0.5] + joint_limit_lower = torch.tensor([[-3.14, -2.0, -1.5, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 0.5, 1.5, 0.1]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set elbow position above its upper limit (0.5) + default_joint_pos = torch.tensor([[0.0, 1.0, 0.0, 0.05]] * num_instances, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError for elbow + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "elbow" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_single_joint(self, device: str): + """Test _validate_cfg with a single joint articulation.""" + num_instances = 2 + num_joints = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits + joint_limit_lower = torch.full((num_instances, num_joints), -0.5, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 0.5, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set position outside limits + default_joint_pos = torch.full((num_instances, num_joints), 1.0, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_0" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_negative_range_limits(self, device: str): + """Test _validate_cfg with limits entirely in the negative range.""" + num_instances = 2 + num_joints = 2 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set limits entirely in negative range + joint_limit_lower = torch.full((num_instances, num_joints), -5.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), -2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set position at zero (outside negative-only limits) + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + # Both joints should be reported as violated + assert "joint_0" in str(exc_info.value) + assert "joint_1" in str(exc_info.value) + + +# TODO: Expand these tests when tendons are available in Newton. +# Currently, tendons are not implemented and _process_tendons only initializes empty lists. +# When tendon support is added, tests should verify: +# - Fixed tendon properties are correctly parsed and stored +# - Spatial tendon properties are correctly parsed and stored +# - Tendon limits and stiffness values are correctly set +class TestProcessTendons: + """Tests for _process_tendons method. + + Note: Tendons are not yet implemented in Newton. These tests verify the current + placeholder behavior. When tendons are implemented, these tests should be expanded. + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_tendons_initializes_empty_lists(self, device: str): + """Test that _process_tendons initializes empty tendon name lists.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Call _process_tendons + articulation._process_tendons() + + # Verify empty lists are created + assert articulation._fixed_tendon_names == [] + assert articulation._spatial_tendon_names == [] + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_tendons_returns_none(self, device: str): + """Test that _process_tendons returns None (no tendons implemented).""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Call _process_tendons and verify return value + result = articulation._process_tendons() + assert result is None + + +# TODO: Expand these tests when actuator mocking is more mature. +# Full actuator integration tests would require: +# - Mocking ActuatorBaseCfg and ActuatorBase classes +# - Testing implicit vs explicit actuator behavior +# - Testing stiffness/damping propagation +# Currently, we test the initialization behavior without actuators configured. +class TestProcessActuatorsCfg: + """Tests for _process_actuators_cfg method. + + Note: These tests focus on the initialization behavior when no actuators are configured. + Full actuator integration tests require additional mocking infrastructure. + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_initializes_empty_dict(self, device: str): + """Test that _process_actuators_cfg initializes actuators as empty dict when none configured.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Ensure no actuators are configured + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify actuators dict is empty + assert articulation.actuators == {} + assert isinstance(articulation.actuators, dict) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_sets_implicit_flag_false(self, device: str): + """Test that _process_actuators_cfg sets _has_implicit_actuators to False initially.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify flag is set to False + assert articulation._has_implicit_actuators is False + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_sets_joint_limit_gains(self, device: str): + """Test that _process_actuators_cfg sets joint_limit_ke and joint_limit_kd.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify joint limit gains are set + joint_limit_ke = wp.to_torch(mock_view.get_attribute("joint_limit_ke", None)) + joint_limit_kd = wp.to_torch(mock_view.get_attribute("joint_limit_kd", None)) + + expected_ke = torch.full((num_instances, num_joints), 2500.0, device=device) + expected_kd = torch.full((num_instances, num_joints), 100.0, device=device) + + torch.testing.assert_close(joint_limit_ke, expected_ke, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(joint_limit_kd, expected_kd, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_warns_unactuated_joints(self, device: str): + """Test that _process_actuators_cfg warns when not all joints have actuators.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # No actuators configured but we have joints + articulation.cfg.actuators = {} + + # Should warn about unactuated joints + with pytest.warns(UserWarning, match="Not all actuators are configured"): + articulation._process_actuators_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_no_warning_zero_joints(self, device: str): + """Test that _process_actuators_cfg does not warn when there are no joints.""" + num_instances = 2 + num_joints = 0 + num_bodies = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + num_bodies=num_bodies, + device=device, + ) + + articulation.cfg.actuators = {} + + # Should not warn when there are no joints to actuate + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error") + # This should not raise a warning + articulation._process_actuators_cfg() + + +## +# Main +## + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/source/isaaclab_newton/test/assets/articulation/test_articulation_data.py b/source/isaaclab_newton/test/assets/articulation/test_articulation_data.py new file mode 100644 index 00000000000..ba6bc7f5641 --- /dev/null +++ b/source/isaaclab_newton/test/assets/articulation/test_articulation_data.py @@ -0,0 +1,3292 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for ArticulationData class comparing Newton implementation against PhysX reference.""" + +from __future__ import annotations + +import torch +from unittest.mock import MagicMock, patch + +import pytest +import warp as wp +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData + +# TODO: Remove this import +from isaaclab.utils import math as math_utils + +# Import mock classes from shared module +from .mock_interface import MockNewtonArticulationView, MockNewtonModel + +# Initialize Warp +wp.init() + + +## +# Test Fixtures +## + + +@pytest.fixture +def mock_newton_manager(): + """Create mock NewtonManager with necessary methods.""" + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + + # Patch where NewtonManager is used (in the articulation_data module) + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager") as MockManager: + MockManager.get_model.return_value = mock_model + MockManager.get_state_0.return_value = mock_state + MockManager.get_control.return_value = mock_control + MockManager.get_dt.return_value = 0.01 + yield MockManager + + +## +# Test Cases -- Defaults. +## + + +class TestDefaults: + """Tests the following properties: + - default_root_pose + - default_root_vel + - default_joint_pos + - default_joint_vel + + Runs the following checks: + - Checks that by default, the properties are all zero. + - Checks that the properties are settable. + - Checks that once the articulation data is primed, the properties cannot be changed. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_zero_instantiated(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test zero instantiated articulation data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types are correct + assert articulation_data.default_root_pose.dtype is wp.transformf + assert articulation_data.default_root_vel.dtype is wp.spatial_vectorf + assert articulation_data.default_joint_pos.dtype is wp.float32 + assert articulation_data.default_joint_vel.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.default_root_pose.shape == (num_instances,) + assert articulation_data.default_root_vel.shape == (num_instances,) + assert articulation_data.default_joint_pos.shape == (num_instances, num_dofs) + assert articulation_data.default_joint_vel.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.default_root_pose) == torch.zeros(num_instances, 7, device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_root_vel) == torch.zeros(num_instances, 6, device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_pos) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_vel) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_settable(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the articulation data is settable.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Set the default values + articulation_data.default_root_pose = wp.ones(num_instances, dtype=wp.transformf, device=device) + articulation_data.default_root_vel = wp.ones(num_instances, dtype=wp.spatial_vectorf, device=device) + articulation_data.default_joint_pos = wp.ones((num_instances, num_dofs), dtype=wp.float32, device=device) + articulation_data.default_joint_vel = wp.ones((num_instances, num_dofs), dtype=wp.float32, device=device) + # Check the types are correct + assert articulation_data.default_root_pose.dtype is wp.transformf + assert articulation_data.default_root_vel.dtype is wp.spatial_vectorf + assert articulation_data.default_joint_pos.dtype is wp.float32 + assert articulation_data.default_joint_vel.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.default_root_pose.shape == (num_instances,) + assert articulation_data.default_root_vel.shape == (num_instances,) + assert articulation_data.default_joint_pos.shape == (num_instances, num_dofs) + assert articulation_data.default_joint_vel.shape == (num_instances, num_dofs) + # Check the values are set + assert torch.all( + wp.to_torch(articulation_data.default_root_pose) == torch.ones(num_instances, 7, device=device) + ) + assert torch.all(wp.to_torch(articulation_data.default_root_vel) == torch.ones(num_instances, 6, device=device)) + assert torch.all( + wp.to_torch(articulation_data.default_joint_pos) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_vel) == torch.ones((num_instances, num_dofs), device=device) + ) + # Prime the articulation data + articulation_data.is_primed = True + # Check that the values cannot be changed + with pytest.raises(RuntimeError): + articulation_data.default_root_pose = wp.zeros(num_instances, dtype=wp.transformf, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_root_vel = wp.zeros(num_instances, dtype=wp.spatial_vectorf, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_joint_pos = wp.zeros((num_instances, num_dofs), dtype=wp.float32, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_joint_vel = wp.zeros((num_instances, num_dofs), dtype=wp.float32, device=device) + + +## +# Test Cases -- Joint Commands (Set into the simulation). +## + + +class TestJointCommandsSetIntoSimulation: + """Tests the following properties: + - joint_pos_target + - joint_vel_target + - joint_effort_target + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint commands are initialized to zero.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types is correct + assert articulation_data.joint_pos_target.dtype is wp.float32 + assert articulation_data.joint_vel_target.dtype is wp.float32 + assert articulation_data.joint_effort.dtype is wp.float32 + # Check the shape is correct + assert articulation_data.joint_pos_target.shape == (num_instances, num_dofs) + assert articulation_data.joint_vel_target.shape == (num_instances, num_dofs) + assert articulation_data.joint_effort.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.joint_pos_target) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_target) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint commands return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Get the pointers + joint_pos_target = articulation_data.joint_pos_target + joint_vel_target = articulation_data.joint_vel_target + joint_effort = articulation_data.joint_effort + # Check that they are zeros + assert torch.all(wp.to_torch(joint_pos_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort) == torch.zeros((num_instances, num_dofs), device=device)) + # Assign a different value to the internal data + articulation_data.joint_pos_target.fill_(1.0) + articulation_data.joint_vel_target.fill_(1.0) + articulation_data.joint_effort.fill_(1.0) + # Check that the joint commands return the new value + assert torch.all(wp.to_torch(joint_pos_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort) == torch.ones((num_instances, num_dofs), device=device)) + # Assign a different value to the pointers + joint_pos_target.fill_(2.0) + joint_vel_target.fill_(2.0) + joint_effort.fill_(2.0) + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_pos_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + + +## +# Test Cases -- Joint Commands (Explicit actuators). +## + + +class TestJointCommandsExplicitActuators: + """Tests the following properties: + - computed_effort + - applied_effort + - actuator_stiffness + - actuator_damping + - actuator_position_target + - actuator_velocity_target + - actuator_effort_target + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the explicit actuator properties are initialized to zero.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types are correct + assert articulation_data.computed_effort.dtype is wp.float32 + assert articulation_data.applied_effort.dtype is wp.float32 + assert articulation_data.actuator_stiffness.dtype is wp.float32 + assert articulation_data.actuator_damping.dtype is wp.float32 + assert articulation_data.actuator_position_target.dtype is wp.float32 + assert articulation_data.actuator_velocity_target.dtype is wp.float32 + assert articulation_data.actuator_effort_target.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.computed_effort.shape == (num_instances, num_dofs) + assert articulation_data.applied_effort.shape == (num_instances, num_dofs) + assert articulation_data.actuator_stiffness.shape == (num_instances, num_dofs) + assert articulation_data.actuator_damping.shape == (num_instances, num_dofs) + assert articulation_data.actuator_position_target.shape == (num_instances, num_dofs) + assert articulation_data.actuator_velocity_target.shape == (num_instances, num_dofs) + assert articulation_data.actuator_effort_target.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.computed_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.applied_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_stiffness) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_damping) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_position_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_velocity_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_effort_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the explicit actuator properties return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Get the pointers + computed_effort = articulation_data.computed_effort + applied_effort = articulation_data.applied_effort + actuator_stiffness = articulation_data.actuator_stiffness + actuator_damping = articulation_data.actuator_damping + actuator_position_target = articulation_data.actuator_position_target + actuator_velocity_target = articulation_data.actuator_velocity_target + actuator_effort_target = articulation_data.actuator_effort_target + # Check that they are zeros + assert torch.all(wp.to_torch(computed_effort) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(applied_effort) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_stiffness) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_damping) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_position_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_velocity_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_effort_target) == torch.zeros((num_instances, num_dofs), device=device)) + # Assign a different value to the internal data + articulation_data.computed_effort.fill_(1.0) + articulation_data.applied_effort.fill_(1.0) + articulation_data.actuator_stiffness.fill_(1.0) + articulation_data.actuator_damping.fill_(1.0) + articulation_data.actuator_position_target.fill_(1.0) + articulation_data.actuator_velocity_target.fill_(1.0) + articulation_data.actuator_effort_target.fill_(1.0) + # Check that the properties return the new value + assert torch.all(wp.to_torch(computed_effort) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(applied_effort) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_stiffness) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_damping) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_position_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_velocity_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_effort_target) == torch.ones((num_instances, num_dofs), device=device)) + # Assign a different value to the pointers + computed_effort.fill_(2.0) + applied_effort.fill_(2.0) + actuator_stiffness.fill_(2.0) + actuator_damping.fill_(2.0) + actuator_position_target.fill_(2.0) + actuator_velocity_target.fill_(2.0) + actuator_effort_target.fill_(2.0) + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.computed_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.applied_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_stiffness) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_damping) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_position_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_velocity_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_effort_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + + +## +# Test Cases -- Joint Properties (Set into Simulation). +## + + +class TestJointPropertiesSetIntoSimulation: + """Tests the following properties: + - joint_stiffness + - joint_damping + - joint_armature + - joint_friction_coeff + - joint_pos_limits_lower + - joint_pos_limits_upper + - joint_pos_limits (read-only, computed from lower and upper) + - joint_vel_limits + - joint_effort_limits + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + + .. note:: joint_pos_limits is read-only and does not change the joint position limits. + """ + + def _setup_method( + self, num_instances: int, num_dofs: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + # return the mock view, so that it doesn't get garbage collected + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint properties are initialized to zero (or ones for limits).""" + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Check the types are correct + assert articulation_data.joint_stiffness.dtype is wp.float32 + assert articulation_data.joint_damping.dtype is wp.float32 + assert articulation_data.joint_armature.dtype is wp.float32 + assert articulation_data.joint_friction_coeff.dtype is wp.float32 + assert articulation_data.joint_pos_limits_lower.dtype is wp.float32 + assert articulation_data.joint_pos_limits_upper.dtype is wp.float32 + assert articulation_data.joint_pos_limits.dtype is wp.vec2f + assert articulation_data.joint_vel_limits.dtype is wp.float32 + assert articulation_data.joint_effort_limits.dtype is wp.float32 + + # Check the shapes are correct + assert articulation_data.joint_stiffness.shape == (num_instances, num_dofs) + assert articulation_data.joint_damping.shape == (num_instances, num_dofs) + assert articulation_data.joint_armature.shape == (num_instances, num_dofs) + assert articulation_data.joint_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits_lower.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits_upper.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits.shape == (num_instances, num_dofs) + assert articulation_data.joint_vel_limits.shape == (num_instances, num_dofs) + assert articulation_data.joint_effort_limits.shape == (num_instances, num_dofs) + + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.joint_stiffness) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_damping) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_armature) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.zeros((num_instances, num_dofs), device=device) + ) + # joint_pos_limits should be (0, 0) for each joint since both lower and upper are 0 + joint_pos_limits = wp.to_torch(articulation_data.joint_pos_limits) + assert torch.all(joint_pos_limits == torch.zeros((num_instances, num_dofs, 2), device=device)) + # vel_limits and effort_limits are initialized to zeros in the mock + assert torch.all( + wp.to_torch(articulation_data.joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort_limits) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint properties return a reference to the internal data. + + Note: joint_pos_limits is read-only and always returns a new computed array. + """ + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Get the pointers + joint_stiffness = articulation_data.joint_stiffness + joint_damping = articulation_data.joint_damping + joint_armature = articulation_data.joint_armature + joint_friction_coeff = articulation_data.joint_friction_coeff + joint_pos_limits_lower = articulation_data.joint_pos_limits_lower + joint_pos_limits_upper = articulation_data.joint_pos_limits_upper + joint_vel_limits = articulation_data.joint_vel_limits + joint_effort_limits = articulation_data.joint_effort_limits + + # Check that they have initial values (zeros or ones based on mock) + assert torch.all(wp.to_torch(joint_stiffness) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_damping) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_armature) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_pos_limits_lower) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_pos_limits_upper) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort_limits) == torch.zeros((num_instances, num_dofs), device=device)) + + # Assign a different value to the internal data + articulation_data.joint_stiffness.fill_(1.0) + articulation_data.joint_damping.fill_(1.0) + articulation_data.joint_armature.fill_(1.0) + articulation_data.joint_friction_coeff.fill_(1.0) + articulation_data.joint_pos_limits_lower.fill_(-1.0) + articulation_data.joint_pos_limits_upper.fill_(1.0) + articulation_data.joint_vel_limits.fill_(2.0) + articulation_data.joint_effort_limits.fill_(2.0) + + # Check that the properties return the new value (reference behavior) + assert torch.all(wp.to_torch(joint_stiffness) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_damping) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_armature) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_friction_coeff) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all( + wp.to_torch(joint_pos_limits_lower) == torch.ones((num_instances, num_dofs), device=device) * -1.0 + ) + assert torch.all(wp.to_torch(joint_pos_limits_upper) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_limits) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + assert torch.all(wp.to_torch(joint_effort_limits) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + + # Check that joint_pos_limits is computed correctly from lower and upper + joint_pos_limits = wp.to_torch(articulation_data.joint_pos_limits) + expected_limits = torch.stack( + [ + torch.ones((num_instances, num_dofs), device=device) * -1.0, + torch.ones((num_instances, num_dofs), device=device), + ], + dim=-1, + ) + assert torch.all(joint_pos_limits == expected_limits) + + # Assign a different value to the pointers + joint_stiffness.fill_(3.0) + joint_damping.fill_(3.0) + joint_armature.fill_(3.0) + joint_friction_coeff.fill_(3.0) + joint_pos_limits_lower.fill_(-2.0) + joint_pos_limits_upper.fill_(2.0) + joint_vel_limits.fill_(4.0) + joint_effort_limits.fill_(4.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_stiffness) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_damping) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_armature) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.ones((num_instances, num_dofs), device=device) * -2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_limits) + == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort_limits) + == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + + # Verify joint_pos_limits reflects the updated lower and upper values + joint_pos_limits_updated = wp.to_torch(articulation_data.joint_pos_limits) + expected_limits_updated = torch.stack( + [ + torch.ones((num_instances, num_dofs), device=device) * -2.0, + torch.ones((num_instances, num_dofs), device=device) * 2.0, + ], + dim=-1, + ) + assert torch.all(joint_pos_limits_updated == expected_limits_updated) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_joint_pos_limits_is_read_only(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that joint_pos_limits returns a new array each time (not a reference). + + Unlike other joint properties, joint_pos_limits is computed on-the-fly from + joint_pos_limits_lower and joint_pos_limits_upper. Modifying the returned array + should not affect the underlying data. + """ + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Get joint_pos_limits twice + limits1 = articulation_data.joint_pos_limits + limits2 = articulation_data.joint_pos_limits + + # They should be separate arrays (not the same reference) + # Modifying one should not affect the other + limits1.fill_(2.0) + + # limits2 should be changed to 2.0 + assert torch.all(wp.to_torch(limits2) == torch.ones((num_instances, num_dofs, 2), device=device) * 2.0) + + # The underlying lower and upper should be unchanged + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.zeros((num_instances, num_dofs), device=device) + ) + + +## +# Test Cases -- Joint Properties (Custom). +## + + +class TestJointPropertiesCustom: + """Tests the following properties: + - joint_dynamic_friction_coeff + - joint_viscous_friction_coeff + - soft_joint_pos_limits + - soft_joint_vel_limits + - gear_ratio + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + + .. note:: gear_ratio is initialized to ones (not zeros). + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_correctly(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the custom joint properties are initialized correctly.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + + # Check the types are correct + assert articulation_data.joint_dynamic_friction_coeff.dtype is wp.float32 + assert articulation_data.joint_viscous_friction_coeff.dtype is wp.float32 + assert articulation_data.soft_joint_pos_limits.dtype is wp.vec2f + assert articulation_data.soft_joint_vel_limits.dtype is wp.float32 + assert articulation_data.gear_ratio.dtype is wp.float32 + + # Check the shapes are correct + assert articulation_data.joint_dynamic_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.joint_viscous_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.soft_joint_pos_limits.shape == (num_instances, num_dofs) + assert articulation_data.soft_joint_vel_limits.shape == (num_instances, num_dofs) + assert articulation_data.gear_ratio.shape == (num_instances, num_dofs) + + # Check the values are initialized correctly + # Most are zeros, but gear_ratio is initialized to ones + assert torch.all( + wp.to_torch(articulation_data.joint_dynamic_friction_coeff) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_viscous_friction_coeff) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_pos_limits) + == torch.zeros((num_instances, num_dofs, 2), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_vel_limits) + == torch.zeros((num_instances, num_dofs), device=device) + ) + # gear_ratio is initialized to ones + assert torch.all( + wp.to_torch(articulation_data.gear_ratio) == torch.ones((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the custom joint properties return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + + # Get the pointers + joint_dynamic_friction_coeff = articulation_data.joint_dynamic_friction_coeff + joint_viscous_friction_coeff = articulation_data.joint_viscous_friction_coeff + soft_joint_pos_limits = articulation_data.soft_joint_pos_limits + soft_joint_vel_limits = articulation_data.soft_joint_vel_limits + gear_ratio = articulation_data.gear_ratio + + # Check that they have initial values + assert torch.all( + wp.to_torch(joint_dynamic_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(joint_viscous_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all(wp.to_torch(soft_joint_pos_limits) == torch.zeros((num_instances, num_dofs, 2), device=device)) + assert torch.all(wp.to_torch(soft_joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(gear_ratio) == torch.ones((num_instances, num_dofs), device=device)) + + # Assign a different value to the internal data + articulation_data.joint_dynamic_friction_coeff.fill_(1.0) + articulation_data.joint_viscous_friction_coeff.fill_(1.0) + articulation_data.soft_joint_pos_limits.fill_(1.0) + articulation_data.soft_joint_vel_limits.fill_(1.0) + articulation_data.gear_ratio.fill_(2.0) + + # Check that the properties return the new value (reference behavior) + assert torch.all( + wp.to_torch(joint_dynamic_friction_coeff) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(joint_viscous_friction_coeff) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all(wp.to_torch(soft_joint_pos_limits) == torch.ones((num_instances, num_dofs, 2), device=device)) + assert torch.all(wp.to_torch(soft_joint_vel_limits) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(gear_ratio) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + + # Assign a different value to the pointers + joint_dynamic_friction_coeff.fill_(3.0) + joint_viscous_friction_coeff.fill_(3.0) + soft_joint_pos_limits.fill_(3.0) + soft_joint_vel_limits.fill_(3.0) + gear_ratio.fill_(4.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_dynamic_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_viscous_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_pos_limits) + == torch.ones((num_instances, num_dofs, 2), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_vel_limits) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.gear_ratio) == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + + +## +# Test Cases -- Fixed Tendon Properties. +## + + +# TODO: Update these tests when fixed tendon support is added to Newton. +class TestFixedTendonProperties: + """Tests the following properties: + - fixed_tendon_stiffness + - fixed_tendon_damping + - fixed_tendon_limit_stiffness + - fixed_tendon_rest_length + - fixed_tendon_offset + - fixed_tendon_pos_limits + + Currently, all these properties raise NotImplementedError as fixed tendons + are not supported in Newton. + + Runs the following checks: + - Checks that all properties raise NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_fixed_tendon_properties_not_implemented(self, mock_newton_manager, device: str): + """Test that all fixed tendon properties raise NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_damping + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_limit_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_rest_length + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_offset + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_pos_limits + + +## +# Test Cases -- Spatial Tendon Properties. +## + + +# TODO: Update these tests when spatial tendon support is added to Newton. +class TestSpatialTendonProperties: + """Tests the following properties: + - spatial_tendon_stiffness + - spatial_tendon_damping + - spatial_tendon_limit_stiffness + - spatial_tendon_offset + + Currently, all these properties raise NotImplementedError as spatial tendons + are not supported in Newton. + + Runs the following checks: + - Checks that all properties raise NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_spatial_tendon_properties_not_implemented(self, mock_newton_manager, device: str): + """Test that all spatial tendon properties raise NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_damping + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_limit_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_offset + + +## +# Test Cases -- Root state properties. +## + + +class TestRootLinkPoseW: + """Tests the root link pose property + + This value is read from the simulation. There is no math to check for. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_link_pose_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root link pose property returns a pointer to the internal data.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_link_pose_w.shape == (num_instances,) + assert articulation_data.root_link_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.zeros((1, 7), device=device)) + + # Get the property + root_link_pose_w = articulation_data.root_link_pose_w + + # Assign a different value to the internal data + articulation_data.root_link_pose_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.ones((1, 7), device=device)) + + # Assign a different value to the pointers + root_link_pose_w.fill_(2.0) + + # Check that the internal data has been updated + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.ones((1, 7), device=device) * 2.0) + + +class TestRootLinkVelW: + """Tests the root link velocity property + + This value is derived from the root center of mass velocity. To ensure that the value is correctly computed, + we will compare the calculated value to the one currently calculated in the version 2.3.1 of IsaacLab. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + - Checks that the timestamp is updated correctly. + - Checks that the data is invalidated when the timestamp is updated. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root link velocity property is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_link_vel_w.shape == (num_instances,) + assert articulation_data.root_link_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_link_vel_w) == torch.zeros((num_instances, 6), device=device) + ) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random com velocity and body com position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Use the original IsaacLab code to compute the root link velocities + vel = com_vel.clone() + # TODO: Move the function from math_utils to a test utils file. Decoupling it from changes in math_utils. + vel[:, :3] += torch.linalg.cross( + vel[:, 3:], math_utils.quat_apply(root_link_pose[:, 3:], -body_com_pos[:, 0]), dim=-1 + ) + + # Compare the computed value to the one from the articulation data + assert torch.allclose(wp.to_torch(articulation_data.root_link_vel_w), vel, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_update_timestamp(self, mock_newton_manager, device: str): + """Test that the timestamp is updated correctly.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check that the timestamp is initialized to -1.0 + assert articulation_data._root_link_vel_w.timestamp == -1.0 + + # Check that the data class timestamp is initialized to 0.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property + value = wp.to_torch(articulation_data.root_link_vel_w).clone() + + # Check that the timestamp is updated. The timestamp should be the same as the data class timestamp. + assert articulation_data._root_link_vel_w.timestamp == articulation_data._sim_timestamp + + # Update the root_com_vel_w + mock_view.set_mock_data( + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Check that the property value was not updated + assert torch.all(wp.to_torch(articulation_data.root_link_vel_w) == value) + + # Update the data class timestamp + articulation_data._sim_timestamp = 1.0 + + # Check that the property timestamp was not updated + assert articulation_data._root_link_vel_w.timestamp != articulation_data._sim_timestamp + + # Check that the property value was updated + assert torch.all(wp.to_torch(articulation_data.root_link_vel_w) != value) + + +class TestRootComPoseW: + """Tests the root center of mass pose property + + This value is derived from the root link pose and the body com position. To ensure that the value is correctly computed, + we will compare the calculated value to the one currently calculated in the version 2.3.1 of IsaacLab. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + - Checks that the timestamp is updated correctly. + - Checks that the data is invalidated when the timestamp is updated. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_com_pose_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root center of mass pose property returns a pointer to the internal data.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_com_pose_w.shape == (num_instances,) + assert articulation_data.root_com_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_com_pose_w) == torch.zeros((num_instances, 7), device=device) + ) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root link pose and body com position + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Use the original IsaacLab code to compute the root center of mass pose + root_link_pos_w = root_link_pose[:, :3] + root_link_quat_w = root_link_pose[:, 3:] + body_com_pos_b = body_com_pos.clone() + body_com_quat_b = torch.zeros((num_instances, 1, 4), device=device) + body_com_quat_b[:, :, 3] = 1.0 + # --- IL 2.3.1 code --- + pos, quat = math_utils.combine_frame_transforms( + root_link_pos_w, root_link_quat_w, body_com_pos_b[:, 0], body_com_quat_b[:, 0] + ) + # --- + root_com_pose = torch.cat((pos, quat), dim=-1) + + # Compare the computed value to the one from the articulation data + assert torch.allclose(wp.to_torch(articulation_data.root_com_pose_w), root_com_pose, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_update_timestamp(self, mock_newton_manager, device: str): + """Test that the timestamp is updated correctly.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check that the timestamp is initialized to -1.0 + assert articulation_data._root_com_pose_w.timestamp == -1.0 + + # Check that the data class timestamp is initialized to 0.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property + value = wp.to_torch(articulation_data.root_com_pose_w).clone() + + # Check that the timestamp is updated. The timestamp should be the same as the data class timestamp. + assert articulation_data._root_com_pose_w.timestamp == articulation_data._sim_timestamp + + # Update the root_com_vel_w + mock_view.set_mock_data( + root_transforms=wp.from_torch(torch.rand((1, 7), device=device), dtype=wp.transformf), + ) + + # Check that the property value was not updated + assert torch.all(wp.to_torch(articulation_data.root_com_pose_w) == value) + + # Update the data class timestamp + articulation_data._sim_timestamp = 1.0 + + # Check that the property timestamp was not updated + assert articulation_data._root_com_pose_w.timestamp != articulation_data._sim_timestamp + + # Check that the property value was updated + assert torch.all(wp.to_torch(articulation_data.root_com_pose_w) != value) + + +class TestRootComVelW: + """Tests the root center of mass velocity property + + This value is read from the simulation. There is no math to check for. + + Checks that the returned value is a pointer to the internal data. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_com_vel_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root center of mass velocity property returns a pointer to the internal data.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_com_vel_w.shape == (num_instances,) + assert articulation_data.root_com_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_com_vel_w) == torch.zeros((num_instances, 6), device=device) + ) + + # Get the property + root_com_vel_w = articulation_data.root_com_vel_w + + # Assign a different value to the internal data + articulation_data.root_com_vel_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all(wp.to_torch(articulation_data.root_com_vel_w) == torch.ones((num_instances, 6), device=device)) + + # Assign a different value to the pointers + root_com_vel_w.fill_(2.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.root_com_vel_w) == torch.ones((num_instances, 6), device=device) * 2.0 + ) + + +class TestRootState: + """Tests the root state properties + + Test the root state properties are correctly updated from the pose and velocity properties. + Tests the following properties: + - root_state_w + - root_link_state_w + - root_com_state_w + + For each property, we run the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly assembled from pose and velocity. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_root_state_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all root state properties correctly combine pose and velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Generate random mock data + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root link pose + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random velocities and com position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_state_w --- + # Combines root_link_pose_w with root_com_vel_w + root_state = wp.to_torch(articulation_data.root_state_w) + expected_root_state = torch.cat([root_link_pose, com_vel], dim=-1) + + assert root_state.shape == (num_instances, 13) + assert torch.allclose(root_state, expected_root_state, atol=1e-6, rtol=1e-6) + + # --- Test root_link_state_w --- + # Combines root_link_pose_w with root_link_vel_w + root_link_state = wp.to_torch(articulation_data.root_link_state_w) + + # Compute expected root_link_vel from com_vel (same as TestRootLinkVelW) + root_link_vel = com_vel.clone() + root_link_vel[:, :3] += torch.linalg.cross( + root_link_vel[:, 3:], math_utils.quat_apply(root_link_pose[:, 3:], -body_com_pos[:, 0]), dim=-1 + ) + expected_root_link_state = torch.cat([root_link_pose, root_link_vel], dim=-1) + + assert root_link_state.shape == (num_instances, 13) + assert torch.allclose(root_link_state, expected_root_link_state, atol=1e-6, rtol=1e-6) + + # --- Test root_com_state_w --- + # Combines root_com_pose_w with root_com_vel_w + root_com_state = wp.to_torch(articulation_data.root_com_state_w) + + # Compute expected root_com_pose from root_link_pose and body_com_pos (same as TestRootComPoseW) + body_com_quat_b = torch.zeros((num_instances, 4), device=device) + body_com_quat_b[:, 3] = 1.0 + root_com_pos, root_com_quat = math_utils.combine_frame_transforms( + root_link_pose[:, :3], root_link_pose[:, 3:], body_com_pos[:, 0], body_com_quat_b + ) + expected_root_com_state = torch.cat([root_com_pos, root_com_quat, com_vel], dim=-1) + + assert root_com_state.shape == (num_instances, 13) + assert torch.allclose(root_com_state, expected_root_com_state, atol=1e-6, rtol=1e-6) + + +## +# Test Cases -- Body state properties. +## + + +class TestBodyMassInertia: + """Tests the body mass and inertia properties. + + These values are read directly from the simulation bindings. + + Tests the following properties: + - body_mass + - body_inertia + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_mass_and_inertia(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_mass and body_inertia have correct types, shapes, and reference behavior.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # --- Test body_mass --- + # Check the type and shape + assert articulation_data.body_mass.shape == (num_instances, num_bodies) + assert articulation_data.body_mass.dtype == wp.float32 + + # Mock data initializes body_mass to ones + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.zeros((num_instances, num_bodies), device=device) + ) + + # Get the property reference + body_mass_ref = articulation_data.body_mass + + # Assign a different value to the internal data via property + articulation_data.body_mass.fill_(2.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.ones((num_instances, num_bodies), device=device) * 2.0 + ) + + # Assign a different value via reference + body_mass_ref.fill_(3.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.ones((num_instances, num_bodies), device=device) * 3.0 + ) + + # --- Test body_inertia --- + # Check the type and shape + assert articulation_data.body_inertia.shape == (num_instances, num_bodies) + assert articulation_data.body_inertia.dtype == wp.mat33f + + # Mock data initializes body_inertia to zeros + expected_inertia = torch.zeros((num_instances, num_bodies, 3, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia) + + # Get the property reference + body_inertia_ref = articulation_data.body_inertia + + # Assign a different value to the internal data via property + articulation_data.body_inertia.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_inertia_ones = torch.ones((num_instances, num_bodies, 3, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia_ones) + + # Assign a different value via reference + body_inertia_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_inertia_twos = torch.ones((num_instances, num_bodies, 3, 3), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia_twos) + + +class TestBodyLinkPoseW: + """Tests the body link pose property. + + This value is read directly from the simulation bindings. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_link_pose_w(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_link_pose_w has correct type, shape, and reference behavior.""" + articulation_data, _ = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_link_pose_w.shape == (num_instances, num_bodies) + assert articulation_data.body_link_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected) + + # Get the property reference + body_link_pose_ref = articulation_data.body_link_pose_w + + # Assign a different value via property + articulation_data.body_link_pose_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected_ones) + + # Assign a different value via reference + body_link_pose_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_bodies, 7), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected_twos) + + +class TestBodyLinkVelW: + """Tests the body link velocity property. + + This value is derived from body COM velocity. To ensure correctness, + we compare against the reference implementation. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_link_vel_w is correctly computed from COM velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_link_vel_w.shape == (num_instances, num_bodies) + assert articulation_data.body_link_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_vel_w) == expected) + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random COM velocity and body COM position + com_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + # Generate random link poses with normalized quaternions + link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + link_pose[..., 3:] = torch.nn.functional.normalize(link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(link_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected link velocity using IsaacLab reference implementation + # vel[:, :3] += cross(vel[:, 3:], quat_apply(quat, -body_com_pos)) + expected_vel = com_vel.clone() + expected_vel[..., :3] += torch.linalg.cross( + expected_vel[..., 3:], + math_utils.quat_apply(link_pose[..., 3:], -body_com_pos), + dim=-1, + ) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_link_vel_w), expected_vel, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, 1, device) + + # Check initial timestamp + assert articulation_data._body_link_vel_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_link_vel_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_link_vel_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_velocities=wp.from_torch(torch.rand((1, 1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_link_vel_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_link_vel_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_link_vel_w) == value) + + +class TestBodyComPoseW: + """Tests the body center of mass pose property. + + This value is derived from body link pose and body COM position. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pose_w is correctly computed from link pose and COM position.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_pose_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pose_w) == expected) + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random link poses with normalized quaternions + link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + link_pose[..., 3:] = torch.nn.functional.normalize(link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random body COM position in body frame + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(link_pose, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected COM pose using IsaacLab reference implementation + # combine_frame_transforms(link_pos, link_quat, com_pos_b, identity_quat) + body_com_quat_b = torch.zeros((num_instances, num_bodies, 4), device=device) + body_com_quat_b[..., 3] = 1.0 # identity quaternion + + expected_pos, expected_quat = math_utils.combine_frame_transforms( + link_pose[..., :3], link_pose[..., 3:], body_com_pos, body_com_quat_b + ) + expected_pose = torch.cat([expected_pos, expected_quat], dim=-1) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_com_pose_w), expected_pose, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, 1, device) + + # Check initial timestamp + assert articulation_data._body_com_pose_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_com_pose_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_com_pose_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_transforms=wp.from_torch(torch.rand((1, 1, 7), device=device), dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_com_pose_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_com_pose_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_com_pose_w) == value) + + +class TestBodyComVelW: + """Tests the body center of mass velocity property. + + This value is read directly from the simulation bindings. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_vel_w(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_vel_w has correct type, shape, and reference behavior.""" + articulation_data, _ = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_vel_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected) + + # Get the property reference + body_com_vel_ref = articulation_data.body_com_vel_w + + # Assign a different value via property + articulation_data.body_com_vel_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected_ones) + + # Assign a different value via reference + body_com_vel_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_bodies, 6), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected_twos) + + +class TestBodyState: + """Tests the body state properties. + + Test the body state properties are correctly updated from the pose and velocity properties. + Tests the following properties: + - body_state_w + - body_link_state_w + - body_com_state_w + + For each property, we run the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly assembled from pose and velocity. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_body_state_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that all body state properties correctly combine pose and velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Generate random mock data + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random body link pose with normalized quaternions + body_link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_link_pose[..., 3:] = torch.nn.functional.normalize(body_link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocities and COM position + com_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_link_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_state_w --- + # Combines body_link_pose_w with body_com_vel_w + body_state = wp.to_torch(articulation_data.body_state_w) + expected_body_state = torch.cat([body_link_pose, com_vel], dim=-1) + + assert body_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_state, expected_body_state, atol=1e-6, rtol=1e-6) + + # --- Test body_link_state_w --- + # Combines body_link_pose_w with body_link_vel_w + body_link_state = wp.to_torch(articulation_data.body_link_state_w) + + # Compute expected body_link_vel from com_vel (same as TestBodyLinkVelW) + body_link_vel = com_vel.clone() + body_link_vel[..., :3] += torch.linalg.cross( + body_link_vel[..., 3:], + math_utils.quat_apply(body_link_pose[..., 3:], -body_com_pos), + dim=-1, + ) + expected_body_link_state = torch.cat([body_link_pose, body_link_vel], dim=-1) + + assert body_link_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_link_state, expected_body_link_state, atol=1e-6, rtol=1e-6) + + # --- Test body_com_state_w --- + # Combines body_com_pose_w with body_com_vel_w + body_com_state = wp.to_torch(articulation_data.body_com_state_w) + + # Compute expected body_com_pose from body_link_pose and body_com_pos (same as TestBodyComPoseW) + body_com_quat_b = torch.zeros((num_instances, num_bodies, 4), device=device) + body_com_quat_b[..., 3] = 1.0 + body_com_pos_w, body_com_quat_w = math_utils.combine_frame_transforms( + body_link_pose[..., :3], body_link_pose[..., 3:], body_com_pos, body_com_quat_b + ) + expected_body_com_state = torch.cat([body_com_pos_w, body_com_quat_w, com_vel], dim=-1) + + assert body_com_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_com_state, expected_body_com_state, atol=1e-6, rtol=1e-6) + + +class TestBodyComAccW: + """Tests the body center of mass acceleration property. + + This value is derived from velocity finite differencing: (current_vel - previous_vel) / dt + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str, initial_vel: torch.Tensor | None = None + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + + # Set initial velocities (these become _previous_body_com_vel) + if initial_vel is not None: + mock_view.set_mock_data( + link_velocities=wp.from_torch(initial_vel, dtype=wp.spatial_vectorf), + ) + else: + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_acc_w is correctly computed from velocity finite differencing.""" + # Initial velocity (becomes previous_velocity) + previous_vel = torch.rand((num_instances, num_bodies, 6), device=device) + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device, previous_vel) + + # Check the type and shape + assert articulation_data.body_com_acc_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_acc_w.dtype == wp.spatial_vectorf + + # dt is mocked as 0.01 + dt = 0.01 + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + + # Generate new random velocity + current_vel = torch.rand((num_instances, num_bodies, 6), device=device) + mock_view.set_mock_data( + link_velocities=wp.from_torch(current_vel, dtype=wp.spatial_vectorf), + ) + + # Compute expected acceleration: (current - previous) / dt + expected_acc = (current_vel - previous_vel) / dt + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_com_acc_w), expected_acc, atol=1e-5, rtol=1e-5) + # Update previous velocity + previous_vel = current_vel.clone() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + initial_vel = torch.zeros((1, 1, 6), device=device) + articulation_data, mock_view = self._setup_method(1, 1, device, initial_vel) + + # Check initial timestamp + assert articulation_data._body_com_acc_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_com_acc_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_com_acc_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_velocities=wp.from_torch(torch.rand((1, 1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_com_acc_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_com_acc_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_com_acc_w) == value) + + +class TestBodyComPoseB: + """Tests the body center of mass pose in body frame property. + + This value is generated from COM position with identity quaternion. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value correctly combines position with identity quaternion. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_pose_b(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pose_b correctly generates pose from position with identity quaternion.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_pose_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pose_b.dtype == wp.transformf + + # Mock data is initialized to zeros for COM position + # Expected pose: [0, 0, 0, 0, 0, 0, 1] (position zeros, identity quaternion) + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + expected[..., 6] = 1.0 # w component of identity quaternion + assert torch.all(wp.to_torch(articulation_data.body_com_pose_b) == expected) + + # Update COM position and verify + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # Get the pose + pose = wp.to_torch(articulation_data.body_com_pose_b) + + # Expected: position from mock, identity quaternion + expected_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + expected_pose[..., :3] = com_pos + expected_pose[..., 6] = 1.0 # w component + + assert torch.allclose(pose, expected_pose, atol=1e-6, rtol=1e-6) + + +# TODO: Update this test when body_incoming_joint_wrench_b support is added to Newton. +class TestBodyIncomingJointWrenchB: + """Tests the body incoming joint wrench property. + + Currently, this property raises NotImplementedError as joint wrenches + are not supported in Newton. + + Runs the following checks: + - Checks that the property raises NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_bodies: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_not_implemented(self, mock_newton_manager, device: str): + """Test that body_incoming_joint_wrench_b raises NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.body_incoming_joint_wrench_b + + +## +# Test Cases -- Joint state properties. +## + + +class TestJointPosVel: + """Tests the joint position and velocity properties. + + These values are read directly from the simulation bindings. + + Tests the following properties: + - joint_pos + - joint_vel + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_joint_pos_and_vel(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that joint_pos and joint_vel have correct type, shape, and reference behavior.""" + articulation_data, mock_view = self._setup_method(num_instances, num_joints, device) + + # --- Test joint_pos --- + # Check the type and shape + assert articulation_data.joint_pos.shape == (num_instances, num_joints) + assert articulation_data.joint_pos.dtype == wp.float32 + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected) + + # Get the property reference + joint_pos_ref = articulation_data.joint_pos + + # Assign a different value via property + articulation_data.joint_pos.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected_ones) + + # Assign a different value via reference + joint_pos_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_joints), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected_twos) + + # --- Test joint_vel --- + # Check the type and shape + assert articulation_data.joint_vel.shape == (num_instances, num_joints) + assert articulation_data.joint_vel.dtype == wp.float32 + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected) + + # Get the property reference + joint_vel_ref = articulation_data.joint_vel + + # Assign a different value via property + articulation_data.joint_vel.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected_ones) + + # Assign a different value via reference + joint_vel_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_joints), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected_twos) + + +class TestJointAcc: + """Tests the joint acceleration property. + + This value is derived from velocity finite differencing: (current_vel - previous_vel) / dt + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str, initial_vel: torch.Tensor | None = None + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + + # Set initial velocities (these become _previous_joint_vel) + if initial_vel is not None: + mock_view.set_mock_data( + dof_velocities=wp.from_torch(initial_vel, dtype=wp.float32), + ) + else: + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that joint_acc is correctly computed from velocity finite differencing.""" + # Initial velocity (becomes previous_velocity) + previous_vel = torch.rand((num_instances, num_joints), device=device) + articulation_data, mock_view = self._setup_method(num_instances, num_joints, device, previous_vel) + + # Check the type and shape + assert articulation_data.joint_acc.shape == (num_instances, num_joints) + assert articulation_data.joint_acc.dtype == wp.float32 + + # dt is mocked as 0.01 + dt = 0.01 + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate new random velocity + current_vel = torch.rand((num_instances, num_joints), device=device) + mock_view.set_mock_data( + dof_velocities=wp.from_torch(current_vel, dtype=wp.float32), + ) + + # Compute expected acceleration: (current - previous) / dt + expected_acc = (current_vel - previous_vel) / dt + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.joint_acc), expected_acc, atol=1e-5, rtol=1e-5) + # Update previous velocity + previous_vel = current_vel.clone() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + initial_vel = torch.zeros((1, 1), device=device) + articulation_data, mock_view = self._setup_method(1, 1, device, initial_vel) + + # Check initial timestamp + assert articulation_data._joint_acc.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.joint_acc).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._joint_acc.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + dof_velocities=wp.from_torch(torch.rand((1, 1), device=device), dtype=wp.float32), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.joint_acc) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._joint_acc.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.joint_acc) == value) + + +## +# Test Cases -- Derived properties. +## + + +class TestProjectedGravityB: + """Tests the projected gravity in body frame property. + + This value is derived by projecting the gravity vector onto the body frame. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that projected_gravity_b is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.projected_gravity_b.shape == (num_instances,) + assert articulation_data.projected_gravity_b.dtype == wp.vec3f + + # Gravity direction (normalized) + gravity_dir = torch.tensor([0.0, 0.0, -1.0], device=device) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + ) + + # Compute expected projected gravity: quat_apply(quat, gravity_dir) + # This rotates gravity from world to body frame + expected = math_utils.quat_apply_inverse(root_pose[:, 3:], gravity_dir.expand(num_instances, 3)) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.projected_gravity_b), expected, atol=1e-4, rtol=1e-4) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._projected_gravity_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.projected_gravity_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._projected_gravity_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.projected_gravity_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._projected_gravity_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.projected_gravity_b) == value) + + +class TestHeadingW: + """Tests the heading in world frame property. + + This value is derived by computing the yaw angle from the forward direction. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that heading_w is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.heading_w.shape == (num_instances,) + assert articulation_data.heading_w.dtype == wp.float32 + + # Forward direction in body frame + forward_vec_b = torch.tensor([1.0, 0.0, 0.0], device=device) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + ) + print(articulation_data._sim_bind_root_link_pose_w) + print(articulation_data.FORWARD_VEC_B) + # Compute expected heading: atan2(rotated_forward.y, rotated_forward.x) + rotated_forward = math_utils.quat_apply(root_pose[:, 3:], forward_vec_b.expand(num_instances, 3)) + expected = torch.atan2(rotated_forward[:, 1], rotated_forward[:, 0]) + print(f"expected: {expected}") + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.heading_w), expected, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._heading_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.heading_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._heading_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.heading_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._heading_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.heading_w) == value) + + +class TestRootLinkVelB: + """Tests the root link velocity in body frame properties. + + Tests the following properties: + - root_link_vel_b: velocity projected to body frame + - root_link_lin_vel_b: linear velocity slice (first 3 components) + - root_link_ang_vel_b: angular velocity slice (last 3 components) + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that lin/ang velocities are correct slices. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that root_link_vel_b and its slices are correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check types and shapes + assert articulation_data.root_link_vel_b.shape == (num_instances,) + assert articulation_data.root_link_vel_b.dtype == wp.spatial_vectorf + + assert articulation_data.root_link_lin_vel_b.shape == (num_instances,) + assert articulation_data.root_link_lin_vel_b.dtype == wp.vec3f + + assert articulation_data.root_link_ang_vel_b.shape == (num_instances,) + assert articulation_data.root_link_ang_vel_b.dtype == wp.vec3f + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocity and body COM position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected root_link_vel_w (same as TestRootLinkVelW) + root_link_vel_w = com_vel.clone() + root_link_vel_w[:, :3] += torch.linalg.cross( + root_link_vel_w[:, 3:], + math_utils.quat_apply(root_pose[:, 3:], -body_com_pos[:, 0]), + dim=-1, + ) + + # Project to body frame using quat_rotate_inv + # Linear velocity: quat_rotate_inv(quat, lin_vel) + # Angular velocity: quat_rotate_inv(quat, ang_vel) + lin_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], root_link_vel_w[:, :3]) + ang_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], root_link_vel_w[:, 3:]) + expected_vel_b = torch.cat([lin_vel_b, ang_vel_b], dim=-1) + + # Get computed values + computed_vel_b = wp.to_torch(articulation_data.root_link_vel_b) + computed_lin_vel_b = wp.to_torch(articulation_data.root_link_lin_vel_b) + computed_ang_vel_b = wp.to_torch(articulation_data.root_link_ang_vel_b) + + # Compare full velocity + assert torch.allclose(computed_vel_b, expected_vel_b, atol=1e-6, rtol=1e-6) + + # Check that lin/ang velocities are correct slices + assert torch.allclose(computed_lin_vel_b, computed_vel_b[:, :3], atol=1e-6, rtol=1e-6) + assert torch.allclose(computed_ang_vel_b, computed_vel_b[:, 3:], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._root_link_vel_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.root_link_vel_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._root_link_vel_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.root_link_vel_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._root_link_vel_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.root_link_vel_b) == value) + + +class TestRootComVelB: + """Tests the root center of mass velocity in body frame properties. + + Tests the following properties: + - root_com_vel_b: COM velocity projected to body frame + - root_com_lin_vel_b: linear velocity slice (first 3 components) + - root_com_ang_vel_b: angular velocity slice (last 3 components) + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that lin/ang velocities are correct slices. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that root_com_vel_b and its slices are correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check types and shapes + assert articulation_data.root_com_vel_b.shape == (num_instances,) + assert articulation_data.root_com_vel_b.dtype == wp.spatial_vectorf + + assert articulation_data.root_com_lin_vel_b.shape == (num_instances,) + assert articulation_data.root_com_lin_vel_b.dtype == wp.vec3f + + assert articulation_data.root_com_ang_vel_b.shape == (num_instances,) + assert articulation_data.root_com_ang_vel_b.dtype == wp.vec3f + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocity (this is root_com_vel_w from simulation) + com_vel_w = torch.rand((num_instances, 6), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel_w, dtype=wp.spatial_vectorf), + ) + + # Project COM velocity to body frame using quat_rotate_inv (quat_conjugate + quat_apply) + lin_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], com_vel_w[:, :3]) + ang_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], com_vel_w[:, 3:]) + expected_vel_b = torch.cat([lin_vel_b, ang_vel_b], dim=-1) + + # Get computed values + computed_vel_b = wp.to_torch(articulation_data.root_com_vel_b) + computed_lin_vel_b = wp.to_torch(articulation_data.root_com_lin_vel_b) + computed_ang_vel_b = wp.to_torch(articulation_data.root_com_ang_vel_b) + + # Compare full velocity + assert torch.allclose(computed_vel_b, expected_vel_b, atol=1e-6, rtol=1e-6) + + # Check that lin/ang velocities are correct slices + assert torch.allclose(computed_lin_vel_b, computed_vel_b[:, :3], atol=1e-6, rtol=1e-6) + assert torch.allclose(computed_ang_vel_b, computed_vel_b[:, 3:], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._root_com_vel_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.root_com_vel_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._root_com_vel_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.root_com_vel_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._root_com_vel_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.root_com_vel_b) == value) + + +## +# Test Cases -- Sliced properties. +## + + +class TestRootSlicedProperties: + """Tests the root sliced properties. + + These properties extract position, quaternion, linear velocity, or angular velocity + from the full pose/velocity arrays. + + Tests the following properties: + - root_link_pos_w, root_link_quat_w (from root_link_pose_w) + - root_link_lin_vel_w, root_link_ang_vel_w (from root_link_vel_w) + - root_com_pos_w, root_com_quat_w (from root_com_pose_w) + - root_com_lin_vel_w, root_com_ang_vel_w (from root_com_vel_w) + + For each property, we only check that they are the correct slice of the parent property. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_root_sliced_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all root sliced properties are correct slices of their parent properties.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_link_pose_w slices --- + root_link_pose = wp.to_torch(articulation_data.root_link_pose_w) + root_link_pos = wp.to_torch(articulation_data.root_link_pos_w) + root_link_quat = wp.to_torch(articulation_data.root_link_quat_w) + + assert root_link_pos.shape == (num_instances, 3) + assert root_link_quat.shape == (num_instances, 4) + assert torch.allclose(root_link_pos, root_link_pose[:, :3], atol=1e-6) + assert torch.allclose(root_link_quat, root_link_pose[:, 3:], atol=1e-6) + + # --- Test root_link_vel_w slices --- + root_link_vel = wp.to_torch(articulation_data.root_link_vel_w) + root_link_lin_vel = wp.to_torch(articulation_data.root_link_lin_vel_w) + root_link_ang_vel = wp.to_torch(articulation_data.root_link_ang_vel_w) + + assert root_link_lin_vel.shape == (num_instances, 3) + assert root_link_ang_vel.shape == (num_instances, 3) + assert torch.allclose(root_link_lin_vel, root_link_vel[:, :3], atol=1e-6) + assert torch.allclose(root_link_ang_vel, root_link_vel[:, 3:], atol=1e-6) + + # --- Test root_com_pose_w slices --- + root_com_pose = wp.to_torch(articulation_data.root_com_pose_w) + root_com_pos = wp.to_torch(articulation_data.root_com_pos_w) + root_com_quat = wp.to_torch(articulation_data.root_com_quat_w) + + assert root_com_pos.shape == (num_instances, 3) + assert root_com_quat.shape == (num_instances, 4) + assert torch.allclose(root_com_pos, root_com_pose[:, :3], atol=1e-6) + assert torch.allclose(root_com_quat, root_com_pose[:, 3:], atol=1e-6) + + # --- Test root_com_vel_w slices --- + root_com_vel = wp.to_torch(articulation_data.root_com_vel_w) + root_com_lin_vel = wp.to_torch(articulation_data.root_com_lin_vel_w) + root_com_ang_vel = wp.to_torch(articulation_data.root_com_ang_vel_w) + + assert root_com_lin_vel.shape == (num_instances, 3) + assert root_com_ang_vel.shape == (num_instances, 3) + assert torch.allclose(root_com_lin_vel, root_com_vel[:, :3], atol=1e-6) + assert torch.allclose(root_com_ang_vel, root_com_vel[:, 3:], atol=1e-6) + + +class TestBodySlicedProperties: + """Tests the body sliced properties. + + These properties extract position, quaternion, linear velocity, or angular velocity + from the full pose/velocity arrays. + + Tests the following properties: + - body_link_pos_w, body_link_quat_w (from body_link_pose_w) + - body_link_lin_vel_w, body_link_ang_vel_w (from body_link_vel_w) + - body_com_pos_w, body_com_quat_w (from body_com_pose_w) + - body_com_lin_vel_w, body_com_ang_vel_w (from body_com_vel_w) + + For each property, we only check that they are the correct slice of the parent property. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_body_sliced_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that all body sliced properties are correct slices of their parent properties.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + body_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_pose[..., 3:] = torch.nn.functional.normalize(body_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + body_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(body_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_link_pose_w slices --- + body_link_pose = wp.to_torch(articulation_data.body_link_pose_w) + body_link_pos = wp.to_torch(articulation_data.body_link_pos_w) + body_link_quat = wp.to_torch(articulation_data.body_link_quat_w) + + assert body_link_pos.shape == (num_instances, num_bodies, 3) + assert body_link_quat.shape == (num_instances, num_bodies, 4) + assert torch.allclose(body_link_pos, body_link_pose[..., :3], atol=1e-6) + assert torch.allclose(body_link_quat, body_link_pose[..., 3:], atol=1e-6) + + # --- Test body_link_vel_w slices --- + body_link_vel = wp.to_torch(articulation_data.body_link_vel_w) + body_link_lin_vel = wp.to_torch(articulation_data.body_link_lin_vel_w) + body_link_ang_vel = wp.to_torch(articulation_data.body_link_ang_vel_w) + + assert body_link_lin_vel.shape == (num_instances, num_bodies, 3) + assert body_link_ang_vel.shape == (num_instances, num_bodies, 3) + assert torch.allclose(body_link_lin_vel, body_link_vel[..., :3], atol=1e-6) + assert torch.allclose(body_link_ang_vel, body_link_vel[..., 3:], atol=1e-6) + + # --- Test body_com_pose_w slices --- + body_com_pose = wp.to_torch(articulation_data.body_com_pose_w) + body_com_pos_w = wp.to_torch(articulation_data.body_com_pos_w) + body_com_quat_w = wp.to_torch(articulation_data.body_com_quat_w) + + assert body_com_pos_w.shape == (num_instances, num_bodies, 3) + assert body_com_quat_w.shape == (num_instances, num_bodies, 4) + assert torch.allclose(body_com_pos_w, body_com_pose[..., :3], atol=1e-6) + assert torch.allclose(body_com_quat_w, body_com_pose[..., 3:], atol=1e-6) + + # --- Test body_com_vel_w slices --- + body_com_vel = wp.to_torch(articulation_data.body_com_vel_w) + body_com_lin_vel = wp.to_torch(articulation_data.body_com_lin_vel_w) + body_com_ang_vel = wp.to_torch(articulation_data.body_com_ang_vel_w) + + assert body_com_lin_vel.shape == (num_instances, num_bodies, 3) + assert body_com_ang_vel.shape == (num_instances, num_bodies, 3) + assert torch.allclose(body_com_lin_vel, body_com_vel[..., :3], atol=1e-6) + assert torch.allclose(body_com_ang_vel, body_com_vel[..., 3:], atol=1e-6) + + +class TestBodyComPosQuatB: + """Tests the body center of mass position and quaternion in body frame properties. + + Tests the following properties: + - body_com_pos_b: COM position in body frame (direct sim binding) + - body_com_quat_b: COM orientation in body frame (derived from body_com_pose_b) + + Runs the following checks: + - Checks that the returned values have the correct type and shape. + - Checks that body_com_pos_b returns the simulation data. + - Checks that body_com_quat_b is the quaternion slice of body_com_pose_b. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_pos_and_quat_b(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pos_b and body_com_quat_b have correct types, shapes, and values.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # --- Test body_com_pos_b --- + # Check the type and shape + assert articulation_data.body_com_pos_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pos_b.dtype == wp.vec3f + + # Mock data is initialized to zeros + expected_pos = torch.zeros((num_instances, num_bodies, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_pos) + + # Update with random COM positions + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # Check that the property returns the mock data + assert torch.allclose(wp.to_torch(articulation_data.body_com_pos_b), com_pos, atol=1e-6) + + # Verify reference behavior + body_com_pos_ref = articulation_data.body_com_pos_b + articulation_data.body_com_pos_b.fill_(1.0) + expected_ones = torch.ones((num_instances, num_bodies, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_ones) + body_com_pos_ref.fill_(2.0) + expected_twos = torch.ones((num_instances, num_bodies, 3), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_twos) + + # --- Test body_com_quat_b --- + # Check the type and shape + assert articulation_data.body_com_quat_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_quat_b.dtype == wp.quatf + + # body_com_quat_b is derived from body_com_pose_b which uses identity quaternion + # body_com_pose_b = [body_com_pos_b, identity_quat] + # So body_com_quat_b should be identity quaternion (0, 0, 0, 1) + body_com_quat = wp.to_torch(articulation_data.body_com_quat_b) + expected_quat = torch.zeros((num_instances, num_bodies, 4), device=device) + expected_quat[..., 3] = 1.0 # w component of identity quaternion + + assert torch.allclose(body_com_quat, expected_quat, atol=1e-6) + + +## +# Test Cases -- Backward compatibility. +## + + +# TODO: Remove this test case in the future. +class TestDefaultRootState: + """Tests the deprecated default_root_state property. + + This property combines default_root_pose and default_root_vel into a vec13f state. + It is deprecated in favor of using default_root_pose and default_root_vel directly. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that it correctly combines default_root_pose and default_root_vel. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_default_root_state(self, mock_newton_manager, num_instances: int, device: str): + """Test that default_root_state correctly combines pose and velocity.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.default_root_state.shape == (num_instances,) + + # Get the combined state + default_state = wp.to_torch(articulation_data.default_root_state) + assert default_state.shape == (num_instances, 13) + + # Get the individual components + default_pose = wp.to_torch(articulation_data.default_root_pose) + default_vel = wp.to_torch(articulation_data.default_root_vel) + + # Verify the state is the concatenation of pose and velocity + expected_state = torch.cat([default_pose, default_vel], dim=-1) + assert torch.allclose(default_state, expected_state, atol=1e-6) + + # Modify default_root_pose and default_root_vel and verify the state updates + new_pose = torch.zeros((num_instances, 7), device=device) + new_pose[:, :3] = torch.rand((num_instances, 3), device=device) + new_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + new_vel = torch.rand((num_instances, 6), device=device) + + # Set the new values + articulation_data.default_root_pose.assign(wp.from_torch(new_pose, dtype=wp.transformf)) + articulation_data.default_root_vel.assign(wp.from_torch(new_vel, dtype=wp.spatial_vectorf)) + + # Verify the state reflects the new values + updated_state = wp.to_torch(articulation_data.default_root_state) + expected_updated_state = torch.cat([new_pose, new_vel], dim=-1) + assert torch.allclose(updated_state, expected_updated_state, atol=1e-6) + + +# TODO: Remove this test case in the future. +class TestDeprecatedRootProperties: + """Tests the deprecated root pose/velocity properties. + + These are backward compatibility aliases that just return the corresponding new property. + + Tests the following deprecated -> new property mappings: + - root_pose_w -> root_link_pose_w + - root_pos_w -> root_link_pos_w + - root_quat_w -> root_link_quat_w + - root_vel_w -> root_com_vel_w + - root_lin_vel_w -> root_com_lin_vel_w + - root_ang_vel_w -> root_com_ang_vel_w + - root_lin_vel_b -> root_com_lin_vel_b + - root_ang_vel_b -> root_com_ang_vel_b + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_deprecated_root_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all deprecated root properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_pose_w -> root_link_pose_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_pose_w), + wp.to_torch(articulation_data.root_link_pose_w), + atol=1e-6, + ) + + # --- Test root_pos_w -> root_link_pos_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_pos_w), + wp.to_torch(articulation_data.root_link_pos_w), + atol=1e-6, + ) + + # --- Test root_quat_w -> root_link_quat_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_quat_w), + wp.to_torch(articulation_data.root_link_quat_w), + atol=1e-6, + ) + + # --- Test root_vel_w -> root_com_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_vel_w), + wp.to_torch(articulation_data.root_com_vel_w), + atol=1e-6, + ) + + # --- Test root_lin_vel_w -> root_com_lin_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_lin_vel_w), + wp.to_torch(articulation_data.root_com_lin_vel_w), + atol=1e-6, + ) + + # --- Test root_ang_vel_w -> root_com_ang_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_ang_vel_w), + wp.to_torch(articulation_data.root_com_ang_vel_w), + atol=1e-6, + ) + + # --- Test root_lin_vel_b -> root_com_lin_vel_b --- + assert torch.allclose( + wp.to_torch(articulation_data.root_lin_vel_b), + wp.to_torch(articulation_data.root_com_lin_vel_b), + atol=1e-6, + ) + + # --- Test root_ang_vel_b -> root_com_ang_vel_b --- + assert torch.allclose( + wp.to_torch(articulation_data.root_ang_vel_b), + wp.to_torch(articulation_data.root_com_ang_vel_b), + atol=1e-6, + ) + + +class TestDeprecatedBodyProperties: + """Tests the deprecated body pose/velocity/acceleration properties. + + These are backward compatibility aliases that just return the corresponding new property. + + Tests the following deprecated -> new property mappings: + - body_pose_w -> body_link_pose_w + - body_pos_w -> body_link_pos_w + - body_quat_w -> body_link_quat_w + - body_vel_w -> body_com_vel_w + - body_lin_vel_w -> body_com_lin_vel_w + - body_ang_vel_w -> body_com_ang_vel_w + - body_acc_w -> body_com_acc_w + - body_lin_acc_w -> body_com_lin_acc_w + - body_ang_acc_w -> body_com_ang_acc_w + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_deprecated_body_properties( + self, mock_newton_manager, num_instances: int, num_bodies: int, device: str + ): + """Test that all deprecated body properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + body_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_pose[..., 3:] = torch.nn.functional.normalize(body_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + body_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(body_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_pose_w -> body_link_pose_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_pose_w), + wp.to_torch(articulation_data.body_link_pose_w), + atol=1e-6, + ) + + # --- Test body_pos_w -> body_link_pos_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_pos_w), + wp.to_torch(articulation_data.body_link_pos_w), + atol=1e-6, + ) + + # --- Test body_quat_w -> body_link_quat_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_quat_w), + wp.to_torch(articulation_data.body_link_quat_w), + atol=1e-6, + ) + + # --- Test body_vel_w -> body_com_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_vel_w), + wp.to_torch(articulation_data.body_com_vel_w), + atol=1e-6, + ) + + # --- Test body_lin_vel_w -> body_com_lin_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_lin_vel_w), + wp.to_torch(articulation_data.body_com_lin_vel_w), + atol=1e-6, + ) + + # --- Test body_ang_vel_w -> body_com_ang_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_ang_vel_w), + wp.to_torch(articulation_data.body_com_ang_vel_w), + atol=1e-6, + ) + + # --- Test body_acc_w -> body_com_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_acc_w), + wp.to_torch(articulation_data.body_com_acc_w), + atol=1e-6, + ) + + # --- Test body_lin_acc_w -> body_com_lin_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_lin_acc_w), + wp.to_torch(articulation_data.body_com_lin_acc_w), + atol=1e-6, + ) + + # --- Test body_ang_acc_w -> body_com_ang_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_ang_acc_w), + wp.to_torch(articulation_data.body_com_ang_acc_w), + atol=1e-6, + ) + + +class TestDeprecatedComProperties: + """Tests the deprecated COM pose properties. + + Tests the following deprecated -> new property mappings: + - com_pos_b -> body_com_pos_b + - com_quat_b -> body_com_quat_b + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_deprecated_com_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that deprecated COM properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # --- Test com_pos_b -> body_com_pos_b --- + assert torch.allclose( + wp.to_torch(articulation_data.com_pos_b), + wp.to_torch(articulation_data.body_com_pos_b), + atol=1e-6, + ) + + # --- Test com_quat_b -> body_com_quat_b --- + assert torch.allclose( + wp.to_torch(articulation_data.com_quat_b), + wp.to_torch(articulation_data.body_com_quat_b), + atol=1e-6, + ) + + +class TestDeprecatedJointMiscProperties: + """Tests the deprecated joint and misc properties. + + Tests the following deprecated -> new property mappings: + - joint_limits -> joint_pos_limits + - joint_friction -> joint_friction_coeff + - applied_torque -> applied_effort + - computed_torque -> computed_effort + - joint_dynamic_friction -> joint_dynamic_friction_coeff + - joint_effort_target -> actuator_effort_target + - joint_viscous_friction -> joint_viscous_friction_coeff + - joint_velocity_limits -> joint_vel_limits + + Note: fixed_tendon_limit -> fixed_tendon_pos_limits is tested separately + as it raises NotImplementedError. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_deprecated_joint_properties(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that deprecated joint properties match their replacements.""" + articulation_data, _ = self._setup_method(num_instances, num_joints, device) + + # --- Test joint_limits -> joint_pos_limits --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_limits), + wp.to_torch(articulation_data.joint_pos_limits), + atol=1e-6, + ) + + # --- Test joint_friction -> joint_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_friction), + wp.to_torch(articulation_data.joint_friction_coeff), + atol=1e-6, + ) + + # --- Test applied_torque -> applied_effort --- + assert torch.allclose( + wp.to_torch(articulation_data.applied_torque), + wp.to_torch(articulation_data.applied_effort), + atol=1e-6, + ) + + # --- Test computed_torque -> computed_effort --- + assert torch.allclose( + wp.to_torch(articulation_data.computed_torque), + wp.to_torch(articulation_data.computed_effort), + atol=1e-6, + ) + + # --- Test joint_dynamic_friction -> joint_dynamic_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_dynamic_friction), + wp.to_torch(articulation_data.joint_dynamic_friction_coeff), + atol=1e-6, + ) + + # --- Test joint_effort_target -> actuator_effort_target --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_effort_target), + wp.to_torch(articulation_data.actuator_effort_target), + atol=1e-6, + ) + + # --- Test joint_viscous_friction -> joint_viscous_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_viscous_friction), + wp.to_torch(articulation_data.joint_viscous_friction_coeff), + atol=1e-6, + ) + + # --- Test joint_velocity_limits -> joint_vel_limits --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_velocity_limits), + wp.to_torch(articulation_data.joint_vel_limits), + atol=1e-6, + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_fixed_tendon_limit_not_implemented(self, mock_newton_manager, device: str): + """Test that fixed_tendon_limit raises NotImplementedError (same as fixed_tendon_pos_limits).""" + articulation_data, _ = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_limit + + +## +# Main +## + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/source/isaaclab_newton/test/assets/rigid_object/__init__.py b/source/isaaclab_newton/test/assets/rigid_object/__init__.py new file mode 100644 index 00000000000..e863516e957 --- /dev/null +++ b/source/isaaclab_newton/test/assets/rigid_object/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for rigid object assets.""" + diff --git a/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object.py b/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object.py new file mode 100644 index 00000000000..1129f238a46 --- /dev/null +++ b/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object.py @@ -0,0 +1,3829 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for Articulation class using mocked dependencies. + +This module provides unit tests for the Articulation class that bypass the heavy +initialization process (`_initialize_impl`) which requires a USD stage and real +simulation infrastructure. + +The key technique is to: +1. Create the Articulation object without calling __init__ using object.__new__ +2. Manually set up the required internal state with mock objects +3. Test individual methods in isolation + +This allows testing the mathematical operations and return values without +requiring full simulation integration. +""" + +from __future__ import annotations + +import torch +from unittest.mock import MagicMock, patch + +import pytest +import warp as wp +from isaaclab_newton.assets.articulation.articulation import Articulation +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData +from isaaclab_newton.kernels import vec13f + +from isaaclab.assets.articulation.articulation_cfg import ArticulationCfg + +# TODO: Move these functions to the test utils so they can't be changed in the future. +from isaaclab.utils.math import combine_frame_transforms, quat_apply, quat_inv + +# Import mock classes from shared module +from .mock_interface import MockNewtonArticulationView, MockNewtonModel + +# Initialize Warp +wp.init() + + +## +# Test Factory - Creates Articulation instances without full initialization +## + + +def create_test_articulation( + num_instances: int = 2, + num_joints: int = 6, + num_bodies: int = 7, + device: str = "cuda:0", + is_fixed_base: bool = False, + joint_names: list[str] | None = None, + body_names: list[str] | None = None, + soft_joint_pos_limit_factor: float = 1.0, +) -> tuple[Articulation, MockNewtonArticulationView, MagicMock]: + """Create a test Articulation instance with mocked dependencies. + + This factory bypasses _initialize_impl and manually sets up the internal state, + allowing unit testing of individual methods without requiring USD/simulation. + + Args: + num_instances: Number of environment instances. + num_joints: Number of joints in the articulation. + num_bodies: Number of bodies in the articulation. + device: Device to use ("cpu" or "cuda:0"). + is_fixed_base: Whether the articulation is fixed-base. + joint_names: Custom joint names. Defaults to ["joint_0", "joint_1", ...]. + body_names: Custom body names. Defaults to ["body_0", "body_1", ...]. + soft_joint_pos_limit_factor: Soft joint position limit factor. + + Returns: + A tuple of (articulation, mock_view, mock_newton_manager). + """ + # Generate default names if not provided + if joint_names is None: + joint_names = [f"joint_{i}" for i in range(num_joints)] + if body_names is None: + body_names = [f"body_{i}" for i in range(num_bodies)] + + # Create the Articulation without calling __init__ + articulation = object.__new__(Articulation) + + # Set up the configuration + articulation.cfg = ArticulationCfg( + prim_path="/World/Robot", + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + actuators={}, + ) + + # Set up the mock view with all parameters + mock_view = MockNewtonArticulationView( + num_instances=num_instances, + num_bodies=num_bodies, + num_joints=num_joints, + device=device, + is_fixed_base=is_fixed_base, + joint_names=joint_names, + body_names=body_names, + ) + mock_view.set_mock_data() + + # Set the view on the articulation (using object.__setattr__ to bypass type checking) + object.__setattr__(articulation, "_root_view", mock_view) + object.__setattr__(articulation, "_device", device) + + # Create mock NewtonManager + mock_newton_manager = MagicMock() + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + mock_newton_manager.get_model.return_value = mock_model + mock_newton_manager.get_state_0.return_value = mock_state + mock_newton_manager.get_control.return_value = mock_control + mock_newton_manager.get_dt.return_value = 0.01 + + # Create ArticulationData with the mock view + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager", mock_newton_manager): + data = ArticulationData(mock_view, device) + # Set the names on the data object (normally done by Articulation._initialize_impl) + data.joint_names = joint_names + data.body_names = body_names + object.__setattr__(articulation, "_data", data) + + return articulation, mock_view, mock_newton_manager + + +## +# Test Fixtures +## + + +@pytest.fixture +def mock_newton_manager(): + """Create mock NewtonManager with necessary methods.""" + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + + # Patch where NewtonManager is used (in the articulation module) + with patch("isaaclab_newton.assets.articulation.articulation.NewtonManager") as MockManager: + MockManager.get_model.return_value = mock_model + MockManager.get_state_0.return_value = mock_state + MockManager.get_control.return_value = mock_control + MockManager.get_dt.return_value = 0.01 + yield MockManager + + +@pytest.fixture +def test_articulation(): + """Create a test articulation with default parameters.""" + articulation, mock_view, mock_manager = create_test_articulation() + yield articulation, mock_view, mock_manager + + +## +# Test Cases -- Properties +## + + +class TestProperties: + """Tests for Articulation properties. + + Tests the following properties: + - data + - num_instances + - is_fixed_base + - num_joints + - num_fixed_tendons + - num_spatial_tendons + - num_bodies + - joint_names + - body_names + """ + + @pytest.mark.parametrize("num_instances", [1, 2, 4]) + def test_num_instances(self, num_instances: int): + """Test the num_instances property returns correct count.""" + articulation, _, _ = create_test_articulation(num_instances=num_instances) + assert articulation.num_instances == num_instances + + @pytest.mark.parametrize("num_joints", [1, 6]) + def test_num_joints(self, num_joints: int): + """Test the num_joints property returns correct count.""" + articulation, _, _ = create_test_articulation(num_joints=num_joints) + assert articulation.num_joints == num_joints + + @pytest.mark.parametrize("num_bodies", [1, 7]) + def test_num_bodies(self, num_bodies: int): + """Test the num_bodies property returns correct count.""" + articulation, _, _ = create_test_articulation(num_bodies=num_bodies) + assert articulation.num_bodies == num_bodies + + @pytest.mark.parametrize("is_fixed_base", [True, False]) + def test_is_fixed_base(self, is_fixed_base: bool): + """Test the is_fixed_base property.""" + articulation, _, _ = create_test_articulation(is_fixed_base=is_fixed_base) + assert articulation.is_fixed_base == is_fixed_base + + # TODO: Update when tendons are supported in Newton. + def test_num_fixed_tendons(self): + """Test that num_fixed_tendons returns 0 (not supported in Newton).""" + articulation, _, _ = create_test_articulation() + # Always returns 0 because fixed tendons are not supported in Newton. + assert articulation.num_fixed_tendons == 0 + + # TODO: Update when tendons are supported in Newton. + def test_num_spatial_tendons(self): + """Test that num_spatial_tendons returns 0 (not supported in Newton).""" + articulation, _, _ = create_test_articulation() + # Always returns 0 because spatial tendons are not supported in Newton. + assert articulation.num_spatial_tendons == 0 + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def testdata_property(self, device: str): + """Test that data property returns ArticulationData instance.""" + articulation, _, _ = create_test_articulation(device=device) + assert isinstance(articulation.data, ArticulationData) + + def test_joint_names(self): + """Test that joint_names returns the correct names.""" + custom_names = ["shoulder", "elbow", "wrist"] + articulation, _, _ = create_test_articulation( + num_joints=3, + joint_names=custom_names, + ) + assert articulation.joint_names == custom_names + + def test_body_names(self): + """Test that body_names returns the correct names.""" + custom_names = ["base", "link1", "link2", "end_effector"] + articulation, _, _ = create_test_articulation( + num_bodies=4, + body_names=custom_names, + ) + assert articulation.body_names == custom_names + + +## +# Test Cases -- Reset +## + + +class TestReset: + """Tests for reset method.""" + + def test_reset(self): + """Test that reset method works properly.""" + articulation, _, _ = create_test_articulation() + articulation.set_external_force_and_torque( + forces=torch.ones(articulation.num_instances, articulation.num_bodies, 3), + torques=torch.ones(articulation.num_instances, articulation.num_bodies, 3), + env_ids=slice(None), + body_ids=slice(None), + body_mask=None, + env_mask=None, + is_global=False, + ) + assert wp.to_torch(articulation.data._sim_bind_body_external_wrench).allclose( + torch.ones_like(wp.to_torch(articulation.data._sim_bind_body_external_wrench)) + ) + articulation.reset() + assert wp.to_torch(articulation.data._sim_bind_body_external_wrench).allclose( + torch.zeros_like(wp.to_torch(articulation.data._sim_bind_body_external_wrench)) + ) + + +## +# Test Cases -- Write Data to Sim. Skipped, this is mostly an integration test. +## + + +## +# Test Cases -- Update +## + + +class TestUpdate: + """Tests for update method.""" + + def test_update(self): + """Test that update method updates the simulation timestamp properly.""" + articulation, _, _ = create_test_articulation() + articulation.update(dt=0.01) + assert articulation.data._sim_timestamp == 0.01 + + +## +# Test Cases -- Finders +## + + +class TestFinders: + """Tests for finder methods.""" + + @pytest.mark.parametrize( + "body_names", + [["body_0", "body_1", "body_2"], ["body_3", "body_4", "body_5"], ["body_1", "body_3", "body_5"], "body_.*"], + ) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_bodies(self, body_names: list[str], device: str): + """Test that find_bodies method works properly.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_bodies(body_names) + if body_names == ["body_0", "body_1", "body_2"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[:3] = True + assert names == ["body_0", "body_1", "body_2"] + assert indices == [0, 1, 2] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == ["body_3", "body_4", "body_5"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[3:6] = True + assert names == ["body_3", "body_4", "body_5"] + assert indices == [3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == ["body_1", "body_3", "body_5"]: + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert names == ["body_1", "body_3", "body_5"] + assert indices == [1, 3, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif body_names == "body_.*": + mask_ref = torch.ones((7,), dtype=torch.bool, device=device) + assert names == ["body_0", "body_1", "body_2", "body_3", "body_4", "body_5", "body_6"] + assert indices == [0, 1, 2, 3, 4, 5, 6] + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_body_with_preserve_order(self, device: str): + """Test that find_bodies method works properly with preserve_order.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_bodies(["body_5", "body_3", "body_1"], preserve_order=True) + assert names == ["body_5", "body_3", "body_1"] + assert indices == [5, 3, 1] + mask_ref = torch.zeros((7,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize( + "joint_names", + [ + ["joint_0", "joint_1", "joint_2"], + ["joint_3", "joint_4", "joint_5"], + ["joint_1", "joint_3", "joint_5"], + "joint_.*", + ], + ) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_joints(self, joint_names: list[str], device: str): + """Test that find_joints method works properly.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_joints(joint_names) + if joint_names == ["joint_0", "joint_1", "joint_2"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[:3] = True + assert names == ["joint_0", "joint_1", "joint_2"] + assert indices == [0, 1, 2] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == ["joint_3", "joint_4", "joint_5"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[3:6] = True + assert names == ["joint_3", "joint_4", "joint_5"] + assert indices == [3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == ["joint_1", "joint_3", "joint_5"]: + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert names == ["joint_1", "joint_3", "joint_5"] + assert indices == [1, 3, 5] + assert wp.to_torch(mask).allclose(mask_ref) + elif joint_names == "joint_.*": + mask_ref = torch.ones((6,), dtype=torch.bool, device=device) + assert names == ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5"] + assert indices == [0, 1, 2, 3, 4, 5] + assert wp.to_torch(mask).allclose(mask_ref) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_find_joints_with_preserve_order(self, device: str): + """Test that find_joints method works properly with preserve_order.""" + articulation, _, _ = create_test_articulation(device=device) + mask, names, indices = articulation.find_joints(["joint_5", "joint_3", "joint_1"], preserve_order=True) + assert names == ["joint_5", "joint_3", "joint_1"] + assert indices == [5, 3, 1] + mask_ref = torch.zeros((6,), dtype=torch.bool, device=device) + mask_ref[1] = True + mask_ref[3] = True + mask_ref[5] = True + assert wp.to_torch(mask).allclose(mask_ref) + + # TODO: Update when tendons are supported in Newton. + def test_find_fixed_tendons(self): + """Test that find_fixed_tendons method works properly.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.find_fixed_tendons(["tendon_0", "tendon_1", "tendon_2"]) + + # TODO: Update when tendons are supported in Newton. + def test_find_spatial_tendons(self): + """Test that find_spatial_tendons method works properly.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.find_spatial_tendons(["tendon_0", "tendon_1", "tendon_2"]) + + +## +# Test Cases -- State Writers +## + + +class TestStateWriters: + """Tests for state writing methods.""" + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0], torch.tensor([0, 1, 2], dtype=torch.int32)]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_state_to_sim(data_warp, env_mask=mask_warp) + # Check results + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + articulation.write_root_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_state_w).allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0], torch.tensor([0, 1, 2], dtype=torch.int32)]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_com_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_com_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_link_state_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_state_to_sim(wp.from_torch(data, dtype=vec13f)) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 13), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=vec13f) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Generate reference data + data_ref = torch.zeros((num_instances, 13), device=device) + data_ref[env_ids] = data + # Write to simulation + articulation.write_root_link_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose( + data[:, 7:13], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids].allclose( + data[:, :7], atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 13), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=vec13f) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_link_state_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data[:, 7:13], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data[:, :7], atol=1e-6, rtol=1e-6) + + +class TestVelocityWriters: + """Tests for velocity writing methods. + + Tests methods like: + - write_root_link_velocity_to_sim + - write_root_com_velocity_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 6), device=device) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 6), device=device) + env_ids = env_ids.to(device=device) + articulation.write_root_link_velocity_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_velocity_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_velocity_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_link_velocity_to_sim(wp.from_torch(data, dtype=wp.spatial_vectorf)) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 6), device=device) + # Generate warp data + data_warp = torch.ones((num_instances, 6), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.spatial_vectorf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_link_velocity_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w)[env_ids] + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose( + root_com_velocity, atol=1e-6, rtol=1e-6 + ) + else: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=wp.spatial_vectorf) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_link_velocity_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_link_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + quat = wp.to_torch(articulation.data.root_link_quat_w) + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + # transform input velocity to center of mass frame + root_com_velocity = data.clone() + root_com_velocity[:, :3] += torch.linalg.cross( + root_com_velocity[:, 3:], quat_apply(quat, com_pos_b), dim=-1 + ) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(root_com_velocity, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_state_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_state_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 6), device=device) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 6), device=device) + env_ids = env_ids.to(device=device) + articulation.write_root_com_velocity_to_sim(data, env_ids=env_ids) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_velocity_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_velocity_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + articulation.write_root_com_velocity_to_sim(wp.from_torch(data, dtype=wp.spatial_vectorf)) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + data = torch.rand((len(env_ids), 6), device=device) + # Generate warp data + data_warp = torch.ones((num_instances, 6), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.spatial_vectorf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_velocity_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.root_com_vel_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Update all envs + data = torch.rand((num_instances, 6), device=device) + # Generate warp data + data_warp = wp.from_torch(data.clone(), dtype=wp.spatial_vectorf) + mask_warp = wp.ones((num_instances,), dtype=wp.bool, device=device) + # Generate reference data + articulation.write_root_com_velocity_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_link_vel_w.timestamp == -1.0 + assert torch.all(wp.to_torch(articulation.data.root_link_vel_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_link_vel_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.root_com_vel_w).allclose(data, atol=1e-6, rtol=1e-6) + + +class TestPoseWriters: + """Tests for pose writing methods. + + Tests methods like: + - write_root_link_pose_to_sim + - write_root_com_pose_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_pose_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_pose_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the pose transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if env_ids is None: + # Update all envs + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + elif isinstance(env_ids, list): + # Update selected envs + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + elif isinstance(env_ids, slice): + # Update all envs + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + else: + # Update envs 0, 1, 2 + data = torch.rand((3, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + env_ids = env_ids.to(device=device) + articulation.write_root_link_pose_to_sim(data, env_ids=env_ids) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_link_pose_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_link_pose_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Update all envs + articulation.write_root_link_pose_to_sim(wp.from_torch(data, dtype=wp.transformf)) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[:, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[:, 3:].allclose(data[:, 3:], atol=1e-6, rtol=1e-6) + else: + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 7), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.transformf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_link_pose_to_sim(data_warp, env_mask=mask_warp) + assert articulation.data._root_com_pose_w.timestamp == -1.0 + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + assert torch.all(wp.to_torch(articulation.data.root_com_pose_w)[env_ids, :3] != data[:, :3]) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids, 3:].allclose( + data[:, 3:], atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_pose_to_sim_torch(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_pose_to_sim method works properly.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset to test the velocity transformation + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_pose_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[:, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(root_link_pose, atol=1e-6, rtol=1e-6) + else: + if isinstance(env_ids, torch.Tensor): + env_ids = env_ids.to(device=device) + # Update selected envs + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Write to simulation + articulation.write_root_com_pose_to_sim(data, env_ids=env_ids) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[env_ids, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose( + root_link_pose, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_root_com_pose_to_sim_with_warp(self, device: str, env_ids, num_instances: int): + """Test that write_root_com_pose_to_sim method works properly with warp arrays.""" + articulation, mock_view, _ = create_test_articulation(num_instances=num_instances, device=device) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + # Set a non-zero body CoM offset + body_com_offset = torch.tensor([0.1, 0.01, 0.05], device=device) + body_comdata = body_com_offset.unsqueeze(0).unsqueeze(0).expand(num_instances, articulation.num_bodies, 3) + root_transforms = torch.rand((num_instances, 7), device=device) + root_transforms[:, 3:7] = torch.nn.functional.normalize(root_transforms[:, 3:7], p=2.0, dim=-1) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_transforms, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_comdata.clone(), dtype=wp.vec3f), + ) + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + data = torch.rand((num_instances, 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + articulation.write_root_com_pose_to_sim(wp.from_torch(data, dtype=wp.transformf)) + assert wp.to_torch(articulation.data.root_com_pose_w).allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[:, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[:, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w).allclose(root_link_pose, atol=1e-6, rtol=1e-6) + else: + data = torch.rand((len(env_ids), 7), device=device) + data[:, 3:7] = torch.nn.functional.normalize(data[:, 3:7], p=2.0, dim=-1) + # Generate warp data + data_warp = torch.ones((num_instances, 7), device=device) + mask_warp = torch.zeros((num_instances,), dtype=torch.bool, device=device) + mask_warp[env_ids] = True + data_warp[env_ids] = data + data_warp = wp.from_torch(data_warp, dtype=wp.transformf) + mask_warp = wp.from_torch(mask_warp, dtype=wp.bool) + # Write to simulation + articulation.write_root_com_pose_to_sim(data_warp, env_mask=mask_warp) + assert wp.to_torch(articulation.data.root_com_pose_w)[env_ids].allclose(data, atol=1e-6, rtol=1e-6) + # get CoM pose in link frame + com_pos_b = wp.to_torch(articulation.data.body_com_pos_b)[env_ids, 0, :] + com_quat_b = wp.to_torch(articulation.data.body_com_quat_b)[env_ids, 0, :] + # transform input CoM pose to link frame + root_link_pos, root_link_quat = combine_frame_transforms( + data[..., :3], + data[..., 3:7], + quat_apply(quat_inv(com_quat_b), -com_pos_b), + quat_inv(com_quat_b), + ) + root_link_pose = torch.cat((root_link_pos, root_link_quat), dim=-1) + assert wp.to_torch(articulation.data.root_link_pose_w)[env_ids, :].allclose( + root_link_pose, atol=1e-6, rtol=1e-6 + ) + + +class TestJointState: + """Tests for joint state writing methods. + + Tests methods: + - write_joint_state_to_sim + - write_joint_position_to_sim + - write_joint_velocity_to_sim + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_state_to_sim_torch(self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel).allclose(data2, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data2, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data2, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_state_to_sim(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data2, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_state_to_sim_warp(self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_state_to_sim( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=None, + joint_mask=None, + ) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel).allclose(data2, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[:, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim(data1_warp, data2_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data2, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=env_mask, + joint_mask=None, + ) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data2, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids_, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_state_to_sim( + data1_warp, data2_warp, env_mask=env_mask, joint_mask=joint_mask + ) + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data2, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_position_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_position_to_sim( + wp.from_torch(data1, dtype=wp.float32), env_mask=None, joint_mask=None + ) + assert wp.to_torch(articulation.data.joint_pos).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=env_mask, joint_mask=None) + assert wp.to_torch(articulation.data.joint_pos)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_position_to_sim(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_pos)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], slice(None), [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + articulation.write_joint_velocity_to_sim(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + articulation.write_joint_velocity_to_sim( + wp.from_torch(data1, dtype=wp.float32), env_mask=None, joint_mask=None + ) + assert wp.to_torch(articulation.data.joint_vel).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=None, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_vel)[:, joint_ids].allclose(data1, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=env_mask, joint_mask=None) + assert wp.to_torch(articulation.data.joint_vel)[env_ids, :].allclose(data1, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + articulation.write_joint_velocity_to_sim(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + assert wp.to_torch(articulation.data.joint_vel)[env_ids_, joint_ids].allclose( + data1, atol=1e-6, rtol=1e-6 + ) + + +## +# Test Cases -- Simulation Parameters Writers. +## + + +class TestWriteJointPropertiesToSim: + """Tests for writing joint properties to the simulation. + + Tests methods: + - write_joint_stiffness_to_sim + - write_joint_damping_to_sim + - write_joint_position_limit_to_sim + - write_joint_velocity_limit_to_sim + - write_joint_effort_limit_to_sim + - write_joint_armature_to_sim + - write_joint_friction_coefficient_to_sim + - write_joint_dynamic_friction_coefficient_to_sim + - write_joint_joint_friction_to_sim + - write_joint_limits_to_sim + """ + + def generic_test_property_writer_torch( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for i in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + if i % 2 == 0: + data1 = torch.rand((num_instances, num_joints), device=device) + else: + data1 = float(i) + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + if i % 2 == 0: + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + assert wp.to_torch(property_data).allclose( + data1 * torch.ones((num_instances, num_joints), device=device), atol=1e-6, rtol=1e-6 + ) + else: + # All envs and selected joints + if i % 2 == 0: + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + else: + data1 = float(i) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[:, joint_ids] = data1 + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), num_joints), device=device) + else: + data1 = float(i) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids, :] = data1 + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + else: + data1 = float(i) + writer_function(data1, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids_, joint_ids] = data1 + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for i in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + if i % 2 == 0: + data1 = torch.rand((num_instances, num_joints), device=device) + data1_warp = wp.from_torch(data1, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + writer_function(data1_warp, env_mask=None, joint_mask=None) + property_data = getattr(articulation.data, property_name) + if i % 2 == 0: + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + assert wp.to_torch(property_data).allclose( + data1 * torch.ones((num_instances, num_joints), device=device), atol=1e-6, rtol=1e-6 + ) + else: + # All envs and selected joints + if i % 2 == 0: + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[:, joint_ids] = data1 + writer_function(data1_warp, env_mask=None, joint_mask=joint_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + if i % 2 == 0: + data1 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids, :] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, joint_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + if i % 2 == 0: + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + else: + data1 = float(i) + data1_warp = data1 + data_ref = torch.zeros((num_instances, num_joints), device=device) + data_ref[env_ids_, joint_ids] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, joint_mask=joint_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_torch_dual( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[:, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + if (joint_ids is None) or (isinstance(joint_ids, slice)): + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + writer_function(data1, data2, env_ids=env_ids, joint_ids=joint_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + property_data = getattr(articulation.data, property_name) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + assert wp.to_torch(property_data)[env_ids_, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp_dual( + self, + device: str, + env_ids, + joint_ids, + num_instances: int, + num_joints: int, + writer_function_name: str, + property_name: str, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_joints=num_joints, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_joints == 1: + if (joint_ids is not None) and (not isinstance(joint_ids, slice)): + joint_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + + for _ in range(5): + if env_ids is None: + if joint_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_joints), device=device) + data2 = torch.rand((num_instances, num_joints), device=device) + writer_function( + wp.from_torch(data1, dtype=wp.float32), + wp.from_torch(data2, dtype=wp.float32), + env_mask=None, + joint_mask=None, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data, atol=2e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(joint_ids)), device=device) + data2 = torch.rand((num_instances, len(joint_ids)), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[:, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[:, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function( + data1_warp, + data2_warp, + env_mask=None, + joint_mask=joint_mask, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[:, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + else: + if joint_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_joints), device=device) + data2 = torch.rand((len(env_ids), num_joints), device=device) + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function( + data1_warp, + data2_warp, + env_mask=env_mask, + joint_mask=None, + ) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids, :].allclose(data, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + data1 = torch.rand((len(env_ids), len(joint_ids)), device=device) + data2 = torch.rand((len(env_ids), len(joint_ids)), device=device) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1_warp = torch.ones((num_instances, num_joints), device=device) + data1_warp[env_ids_, joint_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=wp.float32) + data2_warp = torch.ones((num_instances, num_joints), device=device) + data2_warp[env_ids_, joint_ids] = data2 + data2_warp = wp.from_torch(data2_warp, dtype=wp.float32) + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + joint_mask = torch.zeros((num_joints,), dtype=torch.bool, device=device) + joint_mask[joint_ids] = True + joint_mask = wp.from_torch(joint_mask, dtype=wp.bool) + writer_function(data1_warp, data2_warp, env_mask=env_mask, joint_mask=joint_mask) + data = torch.cat([data1.unsqueeze(-1), data2.unsqueeze(-1)], dim=-1) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data)[env_ids_, joint_ids].allclose(data, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_stiffness_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_stiffness_to_sim", "joint_stiffness" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_stiffness_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_stiffness_to_sim", "joint_stiffness" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_damping_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_damping_to_sim", "joint_damping" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_damping_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_damping_to_sim", "joint_damping" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_velocity_limit_to_sim", + "joint_vel_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_velocity_limit_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_velocity_limit_to_sim", + "joint_vel_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_effort_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_effort_limit_to_sim", + "joint_effort_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_effort_limit_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_effort_limit_to_sim", + "joint_effort_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_armature_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_armature_to_sim", "joint_armature" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_armature_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_armature_to_sim", "joint_armature" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_coefficient_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_friction_coefficient_to_sim", + "joint_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_coefficient_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_friction_coefficient_to_sim", + "joint_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_dynamic_friction_coefficient_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_dynamic_friction_coefficient_to_sim", + "joint_dynamic_friction_coeff", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_dynamic_friction_coefficient_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_dynamic_friction_coefficient_to_sim", + "joint_dynamic_friction_coeff", + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_friction_to_sim", "joint_friction_coeff" + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_friction_to_sim_warp( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_friction_to_sim", "joint_friction_coeff" + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_limit_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch_dual( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_position_limit_to_sim", + "joint_pos_limits", + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_position_limit_to_sim_warp_dual( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp_dual( + device, + env_ids, + joint_ids, + num_instances, + num_joints, + "write_joint_position_limit_to_sim", + "joint_pos_limits", + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_limits_to_sim_torch( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_torch_dual( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_limits_to_sim", "joint_pos_limits" + ) + + # TODO: Remove once the deprecated function is removed. + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("joint_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_write_joint_limits_to_sim_warp_dual( + self, device: str, env_ids, joint_ids, num_instances: int, num_joints: int + ): + self.generic_test_property_writer_warp_dual( + device, env_ids, joint_ids, num_instances, num_joints, "write_joint_limits_to_sim", "joint_pos_limits" + ) + + +## +# Test Cases - Setters. +## + + +class TestSettersBodiesMassCoMInertia: + """Tests for setter methods that set body mass, center of mass, and inertia. + + Tests methods: + - set_masses + - set_coms + - set_inertias + """ + + def generic_test_property_writer_torch( + self, + device: str, + env_ids, + body_ids, + num_instances: int, + num_bodies: int, + writer_function_name: str, + property_name: str, + dtype: type = wp.float32, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_bodies=num_bodies, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_bodies == 1: + if (body_ids is not None) and (not isinstance(body_ids, slice)): + body_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + if dtype == wp.float32: + ndims = tuple() + elif dtype == wp.vec3f: + ndims = (3,) + elif dtype == wp.mat33f: + ndims = ( + 3, + 3, + ) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for _ in range(5): + if (env_ids is None) or (isinstance(env_ids, slice)): + if (body_ids is None) or (isinstance(body_ids, slice)): + # All envs and joints + data1 = torch.rand((num_instances, num_bodies, *ndims), device=device) + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected bodies + data1 = torch.rand((num_instances, len(body_ids), *ndims), device=device) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[:, body_ids] = data1 + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if (body_ids is None) or (isinstance(body_ids, slice)): + # Selected envs and all bodies + data1 = torch.rand((len(env_ids), num_bodies, *ndims), device=device) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids, :] = data1 + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and bodies + data1 = torch.rand((len(env_ids), len(body_ids), *ndims), device=device) + writer_function(data1, env_ids=env_ids, body_ids=body_ids) + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids_, body_ids] = data1 + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + def generic_test_property_writer_warp( + self, + device: str, + env_ids, + body_ids, + num_instances: int, + num_bodies: int, + writer_function_name: str, + property_name: str, + dtype: type = wp.float32, + ): + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, num_bodies=num_bodies, device=device + ) + if num_instances == 1: + if (env_ids is not None) and (not isinstance(env_ids, slice)): + env_ids = [0] + if num_bodies == 1: + if (body_ids is not None) and (not isinstance(body_ids, slice)): + body_ids = [0] + + writer_function = getattr(articulation, writer_function_name) + if dtype == wp.float32: + ndims = tuple() + elif dtype == wp.vec3f: + ndims = (3,) + elif dtype == wp.mat33f: + ndims = ( + 3, + 3, + ) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for _ in range(5): + if env_ids is None: + if body_ids is None: + # All envs and joints + data1 = torch.rand((num_instances, num_bodies, *ndims), device=device) + data1_warp = wp.from_torch(data1, dtype=dtype) + writer_function(data1_warp, env_mask=None, body_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data1, atol=1e-6, rtol=1e-6) + else: + # All envs and selected joints + data1 = torch.rand((num_instances, len(body_ids), *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[:, body_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + body_mask = torch.zeros((num_bodies,), dtype=torch.bool, device=device) + body_mask[body_ids] = True + body_mask = wp.from_torch(body_mask, dtype=wp.bool) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[:, body_ids] = data1 + writer_function(data1_warp, env_mask=None, body_mask=body_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + if body_ids is None: + # Selected envs and all joints + data1 = torch.rand((len(env_ids), num_bodies, *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[env_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids, :] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, body_mask=None) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + else: + # Selected envs and joints + env_ids_ = torch.tensor(env_ids, dtype=torch.int32, device=device) + env_ids_ = env_ids_[:, None] + data1 = torch.rand((len(env_ids), len(body_ids), *ndims), device=device) + data1_warp = torch.ones((num_instances, num_bodies, *ndims), device=device) + data1_warp[env_ids_, body_ids] = data1 + data1_warp = wp.from_torch(data1_warp, dtype=dtype) + data_ref = torch.zeros((num_instances, num_bodies, *ndims), device=device) + data_ref[env_ids_, body_ids] = data1 + env_mask = torch.zeros((num_instances,), dtype=torch.bool, device=device) + env_mask[env_ids] = True + env_mask = wp.from_torch(env_mask, dtype=wp.bool) + body_mask = torch.zeros((num_bodies,), dtype=torch.bool, device=device) + body_mask[body_ids] = True + body_mask = wp.from_torch(body_mask, dtype=wp.bool) + writer_function(data1_warp, env_mask=env_mask, body_mask=body_mask) + property_data = getattr(articulation.data, property_name) + assert wp.to_torch(property_data).allclose(data_ref, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_masses_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_masses", "body_mass", dtype=wp.float32 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_masses_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_masses", "body_mass", dtype=wp.float32 + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_coms_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_coms", "body_com_pos_b", dtype=wp.vec3f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_coms_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_coms", "body_com_pos_b", dtype=wp.vec3f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_inertias_to_sim_torch(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_torch( + device, env_ids, body_ids, num_instances, num_bodies, "set_inertias", "body_inertia", dtype=wp.mat33f + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("env_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("body_ids", [None, [0, 1, 2], [0]]) + @pytest.mark.parametrize("num_bodies", [1, 6]) + @pytest.mark.parametrize("num_instances", [1, 4]) + def test_set_inertias_to_sim_warp(self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int): + self.generic_test_property_writer_warp( + device, env_ids, body_ids, num_instances, num_bodies, "set_inertias", "body_inertia", dtype=wp.mat33f + ) + + +# TODO: Implement these tests once the Wrench Composers made it to main IsaacLab. +class TestSettersExternalWrench: + """Tests for setter methods that set external wrench. + + Tests methods: + - set_external_force_and_torque + """ + + @pytest.mark.skip(reason="Not implemented") + def test_external_force_and_torque_to_sim_torch( + self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int + ): + raise NotImplementedError() + + @pytest.mark.skip(reason="Not implemented") + def test_external_force_and_torque_to_sim_warp( + self, device: str, env_ids, body_ids, num_instances: int, num_bodies: int + ): + raise NotImplementedError() + + +class TestFixedTendonsSetters: + """Tests for setter methods that set fixed tendon properties. + + Tests methods: + - set_fixed_tendon_stiffness + - set_fixed_tendon_damping + - set_fixed_tendon_limit_stiffness + - set_fixed_tendon_position_limit + - set_fixed_tendon_limit (deprecated) + - set_fixed_tendon_rest_length + - set_fixed_tendon_offset + - write_fixed_tendon_properties_to_sim + """ + + def test_set_fixed_tendon_stiffness_not_implemented(self): + """Test that set_fixed_tendon_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_stiffness(stiffness) + + def test_set_fixed_tendon_damping_not_implemented(self): + """Test that set_fixed_tendon_damping raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + damping = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_damping(damping) + + def test_set_fixed_tendon_limit_stiffness_not_implemented(self): + """Test that set_fixed_tendon_limit_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit_stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_limit_stiffness(limit_stiffness) + + def test_set_fixed_tendon_position_limit_not_implemented(self): + """Test that set_fixed_tendon_position_limit raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_position_limit(limit) + + def test_set_fixed_tendon_limit_not_implemented(self): + """Test that set_fixed_tendon_limit (deprecated) raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_limit(limit) + + def test_set_fixed_tendon_rest_length_not_implemented(self): + """Test that set_fixed_tendon_rest_length raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + rest_length = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_rest_length(rest_length) + + def test_set_fixed_tendon_offset_not_implemented(self): + """Test that set_fixed_tendon_offset raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + offset = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_fixed_tendon_offset(offset) + + def test_write_fixed_tendon_properties_to_sim_not_implemented(self): + """Test that write_fixed_tendon_properties_to_sim raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.write_fixed_tendon_properties_to_sim() + + +class TestSpatialTendonsSetters: + """Tests for setter methods that set spatial tendon properties. + + Tests methods: + - set_spatial_tendon_stiffness + - set_spatial_tendon_damping + - set_spatial_tendon_limit_stiffness + - set_spatial_tendon_offset + - write_spatial_tendon_properties_to_sim + """ + + def test_set_spatial_tendon_stiffness_not_implemented(self): + """Test that set_spatial_tendon_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_stiffness(stiffness) + + def test_set_spatial_tendon_damping_not_implemented(self): + """Test that set_spatial_tendon_damping raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + damping = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_damping(damping) + + def test_set_spatial_tendon_limit_stiffness_not_implemented(self): + """Test that set_spatial_tendon_limit_stiffness raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + limit_stiffness = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_limit_stiffness(limit_stiffness) + + def test_set_spatial_tendon_offset_not_implemented(self): + """Test that set_spatial_tendon_offset raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + offset = wp.zeros((2, 1), dtype=wp.float32, device="cuda:0") + with pytest.raises(NotImplementedError): + articulation.set_spatial_tendon_offset(offset) + + def test_write_spatial_tendon_properties_to_sim_not_implemented(self): + """Test that write_spatial_tendon_properties_to_sim raises NotImplementedError.""" + articulation, _, _ = create_test_articulation() + with pytest.raises(NotImplementedError): + articulation.write_spatial_tendon_properties_to_sim() + + +class TestCreateBuffers: + """Tests for _create_buffers method. + + Tests that the buffers are created correctly: + - _ALL_INDICES tensor contains correct indices for varying number of environments + - soft_joint_pos_limits are correctly computed based on soft_joint_pos_limit_factor + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + @pytest.mark.parametrize("num_instances", [1, 2, 4, 10, 100]) + def test_create_buffers_all_indices(self, device: str, num_instances: int): + """Test that _ALL_INDICES contains correct indices for varying number of environments.""" + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up joint limits (required for _create_buffers) + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES + expected_indices = torch.arange(num_instances, dtype=torch.long, device=device) + assert articulation._ALL_INDICES.shape == (num_instances,) + assert articulation._ALL_INDICES.dtype == torch.long + assert articulation._ALL_INDICES.device.type == device.split(":")[0] + torch.testing.assert_close(articulation._ALL_INDICES, expected_indices) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_1(self, device: str): + """Test soft_joint_pos_limits with factor=1.0 (limits unchanged).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 1.0 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # With factor=1.0, soft limits should equal hard limits + # soft_joint_pos_limits is wp.vec2f (lower, upper) + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + # Shape is (num_instances, num_joints, 2) after conversion + expected_lower = torch.full((num_instances, num_joints), -2.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 2.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_half(self, device: str): + """Test soft_joint_pos_limits with factor=0.5 (limits halved around mean).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 0.5 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] for all joints + # mean = 0.0, range = 4.0 + # soft_lower = 0.0 - 0.5 * 4.0 * 0.5 = -1.0 + # soft_upper = 0.0 + 0.5 * 4.0 * 0.5 = 1.0 + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify soft limits are halved + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + expected_lower = torch.full((num_instances, num_joints), -1.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 1.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_asymmetric(self, device: str): + """Test soft_joint_pos_limits with asymmetric joint limits.""" + num_instances = 2 + num_joints = 3 + soft_joint_pos_limit_factor = 0.8 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up asymmetric joint limits + # Joint 0: [-3.14, 3.14] -> mean=0, range=6.28 -> soft: [-2.512, 2.512] + # Joint 1: [-1.0, 2.0] -> mean=0.5, range=3.0 -> soft: [0.5-1.2, 0.5+1.2] = [-0.7, 1.7] + # Joint 2: [0.0, 1.0] -> mean=0.5, range=1.0 -> soft: [0.5-0.4, 0.5+0.4] = [0.1, 0.9] + joint_limit_lower = torch.tensor([[-3.14, -1.0, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 2.0, 1.0]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Calculate expected soft limits + # soft_lower = mean - 0.5 * range * factor + # soft_upper = mean + 0.5 * range * factor + expected_lower = torch.tensor([[-2.512, -0.7, 0.1]] * num_instances, device=device) + expected_upper = torch.tensor([[2.512, 1.7, 0.9]] * num_instances, device=device) + + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_factor_zero(self, device: str): + """Test soft_joint_pos_limits with factor=0.0 (limits collapse to mean).""" + num_instances = 2 + num_joints = 4 + soft_joint_pos_limit_factor = 0.0 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Set up joint limits: [-2.0, 2.0] + # mean = 0.0, with factor=0.0, soft limits collapse to [0, 0] + joint_limit_lower = torch.full((num_instances, num_joints), -2.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # With factor=0.0, soft limits should collapse to the mean + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + expected_lower = torch.full((num_instances, num_joints), 0.0, device=device) + expected_upper = torch.full((num_instances, num_joints), 0.0, device=device) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_soft_joint_limits_per_joint_different(self, device: str): + """Test soft_joint_pos_limits with different limits per joint.""" + num_instances = 3 + num_joints = 4 + soft_joint_pos_limit_factor = 0.9 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + soft_joint_pos_limit_factor=soft_joint_pos_limit_factor, + device=device, + ) + + # Each joint has different limits + joint_limit_lower = torch.tensor([[-1.0, -2.0, -0.5, -3.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[1.0, 2.0, 0.5, 3.0]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Calculate expected: soft_lower/upper = mean ± 0.5 * range * factor + # Joint 0: mean=0, range=2 -> [0 - 0.9, 0 + 0.9] = [-0.9, 0.9] + # Joint 1: mean=0, range=4 -> [0 - 1.8, 0 + 1.8] = [-1.8, 1.8] + # Joint 2: mean=0, range=1 -> [0 - 0.45, 0 + 0.45] = [-0.45, 0.45] + # Joint 3: mean=0, range=6 -> [0 - 2.7, 0 + 2.7] = [-2.7, 2.7] + expected_lower = torch.tensor([[-0.9, -1.8, -0.45, -2.7]] * num_instances, device=device) + expected_upper = torch.tensor([[0.9, 1.8, 0.45, 2.7]] * num_instances, device=device) + + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + torch.testing.assert_close(soft_limits[:, :, 0], expected_lower, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(soft_limits[:, :, 1], expected_upper, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_single_environment(self, device: str): + """Test _create_buffers with a single environment.""" + num_instances = 1 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES has single element + assert articulation._ALL_INDICES.shape == (1,) + assert articulation._ALL_INDICES[0].item() == 0 + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_create_buffers_large_number_of_environments(self, device: str): + """Test _create_buffers with a large number of environments.""" + num_instances = 1024 + num_joints = 12 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Call _create_buffers + articulation._create_buffers() + + # Verify _ALL_INDICES + expected_indices = torch.arange(num_instances, dtype=torch.long, device=device) + assert articulation._ALL_INDICES.shape == (num_instances,) + torch.testing.assert_close(articulation._ALL_INDICES, expected_indices) + + # Verify soft limits shape + soft_limits = wp.to_torch(articulation.data.soft_joint_pos_limits) + assert soft_limits.shape == (num_instances, num_joints, 2) + + +class TestProcessCfg: + """Tests for _process_cfg method. + + Tests that the configuration processing correctly: + - Converts quaternion from (w, x, y, z) to (x, y, z, w) format for default root pose + - Sets default root velocity from lin_vel and ang_vel + - Sets default joint positions from joint_pos dict with pattern matching + - Sets default joint velocities from joint_vel dict with pattern matching + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_root_pose(self, device: str): + """Test that _process_cfg correctly converts quaternion format for root pose.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with specific position and rotation + # Rotation is in (w, x, y, z) format in the config + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (0.707, 0.0, 0.707, 0.0) # w, x, y, z + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root pose + # Expected: position (1, 2, 3) + quaternion converted to (x, y, z, w) = (0, 0.707, 0, 0.707) + expected_pose = torch.tensor( + [[1.0, 2.0, 3.0, 0.0, 0.707, 0.0, 0.707]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_pose) + assert result.allclose(expected_pose, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_root_velocity(self, device: str): + """Test that _process_cfg correctly sets default root velocity.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with specific velocities + articulation.cfg.init_state.lin_vel = (1.0, 2.0, 3.0) + articulation.cfg.init_state.ang_vel = (0.1, 0.2, 0.3) + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root velocity + # Expected: lin_vel + ang_vel = (1, 2, 3, 0.1, 0.2, 0.3) + expected_vel = torch.tensor( + [[1.0, 2.0, 3.0, 0.1, 0.2, 0.3]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_vel) + assert result.allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_all_joints(self, device: str): + """Test that _process_cfg correctly sets default joint positions for all joints.""" + num_instances = 2 + num_joints = 4 + joint_names = ["joint_0", "joint_1", "joint_2", "joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with joint positions using wildcard pattern + articulation.cfg.init_state.joint_pos = {".*": 0.5} + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.full((num_instances, num_joints), 0.5, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_specific_joints(self, device: str): + """Test that _process_cfg correctly sets default joint positions for specific joints.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with specific joint positions + articulation.cfg.init_state.joint_pos = { + "shoulder": 1.0, + "elbow": 2.0, + "wrist": 3.0, + "gripper": 4.0, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.tensor([[1.0, 2.0, 3.0, 4.0]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_positions_regex_pattern(self, device: str): + """Test that _process_cfg correctly handles regex patterns for joint positions.""" + num_instances = 2 + num_joints = 6 + joint_names = ["arm_joint_1", "arm_joint_2", "arm_joint_3", "hand_joint_1", "hand_joint_2", "hand_joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with regex patterns + articulation.cfg.init_state.joint_pos = { + "arm_joint_.*": 1.5, + "hand_joint_.*": 0.5, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + # arm joints (indices 0-2) should be 1.5, hand joints (indices 3-5) should be 0.5 + expected_pos = torch.tensor([[1.5, 1.5, 1.5, 0.5, 0.5, 0.5]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_default_joint_velocities(self, device: str): + """Test that _process_cfg correctly sets default joint velocities.""" + num_instances = 2 + num_joints = 4 + joint_names = ["joint_0", "joint_1", "joint_2", "joint_3"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with joint velocities + articulation.cfg.init_state.joint_pos = {".*": 0.0} + articulation.cfg.init_state.joint_vel = { + "joint_0": 0.1, + "joint_1": 0.2, + "joint_2": 0.3, + "joint_3": 0.4, + } + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint velocities + expected_vel = torch.tensor([[0.1, 0.2, 0.3, 0.4]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_vel) + assert result.allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_identity_quaternion(self, device: str): + """Test that _process_cfg correctly handles identity quaternion.""" + num_instances = 2 + num_joints = 2 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up init_state with identity quaternion (w=1, x=0, y=0, z=0) + articulation.cfg.init_state.pos = (0.0, 0.0, 0.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) # Identity: w, x, y, z + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default root pose + # Expected: position (0, 0, 0) + quaternion converted to (x, y, z, w) = (0, 0, 0, 1) + expected_pose = torch.tensor( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]] * num_instances, + device=device, + ) + result = wp.to_torch(articulation.data.default_root_pose) + assert result.allclose(expected_pose, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_zero_joints(self, device: str): + """Test that _process_cfg handles articulation with no joints.""" + num_instances = 2 + num_joints = 0 + num_bodies = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + num_bodies=num_bodies, + device=device, + ) + + # Set up init_state + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) + articulation.cfg.init_state.lin_vel = (0.5, 0.5, 0.5) + articulation.cfg.init_state.ang_vel = (0.1, 0.1, 0.1) + articulation.cfg.init_state.joint_pos = {} + articulation.cfg.init_state.joint_vel = {} + + # Call _process_cfg - should not raise any exception + articulation._process_cfg() + + # Verify root pose and velocity are still set correctly + expected_pose = torch.tensor( + [[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0]] * num_instances, + device=device, + ) + expected_vel = torch.tensor( + [[0.5, 0.5, 0.5, 0.1, 0.1, 0.1]] * num_instances, + device=device, + ) + assert wp.to_torch(articulation.data.default_root_pose).allclose(expected_pose, atol=1e-5, rtol=1e-5) + assert wp.to_torch(articulation.data.default_root_vel).allclose(expected_vel, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_mixed_joint_patterns(self, device: str): + """Test that _process_cfg handles mixed specific and pattern-based joint settings.""" + num_instances = 2 + num_joints = 5 + joint_names = ["base_joint", "arm_1", "arm_2", "hand_1", "hand_2"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set up init_state with mixed patterns + articulation.cfg.init_state.joint_pos = { + "base_joint": 0.0, + "arm_.*": 1.0, + "hand_.*": 2.0, + } + articulation.cfg.init_state.joint_vel = {".*": 0.0} + + # Call _process_cfg + articulation._process_cfg() + + # Verify the default joint positions + expected_pos = torch.tensor([[0.0, 1.0, 1.0, 2.0, 2.0]] * num_instances, device=device) + result = wp.to_torch(articulation.data.default_joint_pos) + assert result.allclose(expected_pos, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose(self, device: str): + """Test that _process_cfg offsets the spawned position by the default root pose.""" + num_instances = 3 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose in config: position (1.0, 2.0, 3.0), identity quaternion + articulation.cfg.init_state.pos = (1.0, 2.0, 3.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) # w, x, y, z (identity) + + # Set up initial spawned positions for each instance + # Instance 0: (5.0, 6.0, 0.0) + # Instance 1: (10.0, 20.0, 0.0) + # Instance 2: (-3.0, -4.0, 0.0) + spawned_transforms = torch.tensor( + [ + [5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0], # pos (x,y,z), quat (x,y,z,w) + [10.0, 20.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [-3.0, -4.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # Verify that the root transforms are offset by default pose's x,y + # Expected: spawned_pose[:, :2] + default_pose[:2] + # Instance 0: (5.0 + 1.0, 6.0 + 2.0, 3.0) = (6.0, 8.0, 3.0) + # Instance 1: (10.0 + 1.0, 20.0 + 2.0, 3.0) = (11.0, 22.0, 3.0) + # Instance 2: (-3.0 + 1.0, -4.0 + 2.0, 3.0) = (-2.0, -2.0, 3.0) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [6.0, 8.0, 3.0, 0.0, 0.0, 0.0, 1.0], + [11.0, 22.0, 3.0, 0.0, 0.0, 0.0, 1.0], + [-2.0, -2.0, 3.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose_zero_offset(self, device: str): + """Test that _process_cfg with zero default position keeps spawned position unchanged in x,y.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose with zero position + articulation.cfg.init_state.pos = (0.0, 0.0, 0.0) + articulation.cfg.init_state.rot = (1.0, 0.0, 0.0, 0.0) + + # Set up initial spawned positions + spawned_transforms = torch.tensor( + [ + [5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 1.0], + [10.0, 20.0, 30.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # With zero default position, x,y should stay the same, z comes from default (0.0) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [10.0, 20.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_cfg_offsets_spawned_pose_with_rotation(self, device: str): + """Test that _process_cfg correctly sets rotation while offsetting position.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set up default root pose with specific rotation (90 degrees around z-axis) + # Quaternion for 90 degrees around z: (w=0.707, x=0, y=0, z=0.707) + articulation.cfg.init_state.pos = (1.0, 2.0, 5.0) + articulation.cfg.init_state.rot = (0.707, 0.0, 0.0, 0.707) # w, x, y, z + + # Set up initial spawned positions + spawned_transforms = torch.tensor( + [ + [3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [6.0, 8.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + mock_view.set_mock_data( + root_transforms=wp.from_torch(spawned_transforms, dtype=wp.transformf), + ) + + # Call _process_cfg + articulation._process_cfg() + + # Verify position offset and rotation is set correctly + # Position: spawned[:2] + default[:2], z from default + # Rotation: from default (converted to x,y,z,w format) + result = wp.to_torch(mock_view.get_root_transforms(None)) + expected_transforms = torch.tensor( + [ + [4.0, 6.0, 5.0, 0.0, 0.0, 0.707, 0.707], # x,y,z, qx,qy,qz,qw + [7.0, 10.0, 5.0, 0.0, 0.0, 0.707, 0.707], + ], + device=device, + ) + torch.testing.assert_close(result, expected_transforms, atol=1e-3, rtol=1e-3) + + +class TestValidateCfg: + """Tests for _validate_cfg method. + + Tests that the configuration validation correctly catches: + - Default joint positions outside of joint limits (lower and upper bounds) + - Various edge cases with joint limits + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_positions_within_limits(self, device: str): + """Test that _validate_cfg passes when all default positions are within limits.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint positions within limits + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should not raise any exception + articulation._validate_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_position_below_lower_limit(self, device: str): + """Test that _validate_cfg raises ValueError when a position is below the lower limit.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint position for joint 2 below the lower limit + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 2] = -1.5 # Below -1.0 lower limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_2" in str(exc_info.value) + assert "-1.500" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_position_above_upper_limit(self, device: str): + """Test that _validate_cfg raises ValueError when a position is above the upper limit.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set default joint position for joint 4 above the upper limit + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 4] = 1.5 # Above 1.0 upper limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_4" in str(exc_info.value) + assert "1.500" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_multiple_positions_out_of_limits(self, device: str): + """Test that _validate_cfg reports all joints with positions outside limits.""" + num_instances = 2 + num_joints = 6 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits: [-1.0, 1.0] for all joints + joint_limit_lower = torch.full((num_instances, num_joints), -1.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 1.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set multiple joints out of limits + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + default_joint_pos[:, 0] = -2.0 # Below lower limit + default_joint_pos[:, 3] = 2.0 # Above upper limit + default_joint_pos[:, 5] = -1.5 # Below lower limit + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError mentioning all violated joints + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + error_msg = str(exc_info.value) + assert "joint_0" in error_msg + assert "joint_3" in error_msg + assert "joint_5" in error_msg + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_asymmetric_limits(self, device: str): + """Test that _validate_cfg works with asymmetric joint limits.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set asymmetric joint limits for each joint + joint_limit_lower = torch.tensor([[-3.14, -2.0, -1.5, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 0.5, 1.5, 0.1]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set positions within asymmetric limits + default_joint_pos = torch.tensor([[0.0, -1.0, 0.0, 0.05]] * num_instances, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should not raise any exception + articulation._validate_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_asymmetric_limits_violated(self, device: str): + """Test that _validate_cfg catches violations with asymmetric limits.""" + num_instances = 2 + num_joints = 4 + joint_names = ["shoulder", "elbow", "wrist", "gripper"] + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + joint_names=joint_names, + device=device, + ) + + # Set asymmetric joint limits: elbow has range [-2.0, 0.5] + joint_limit_lower = torch.tensor([[-3.14, -2.0, -1.5, 0.0]] * num_instances, device=device) + joint_limit_upper = torch.tensor([[3.14, 0.5, 1.5, 0.1]] * num_instances, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set elbow position above its upper limit (0.5) + default_joint_pos = torch.tensor([[0.0, 1.0, 0.0, 0.05]] * num_instances, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError for elbow + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "elbow" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_single_joint(self, device: str): + """Test _validate_cfg with a single joint articulation.""" + num_instances = 2 + num_joints = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set joint limits + joint_limit_lower = torch.full((num_instances, num_joints), -0.5, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), 0.5, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set position outside limits + default_joint_pos = torch.full((num_instances, num_joints), 1.0, device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + assert "joint_0" in str(exc_info.value) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_validate_cfg_negative_range_limits(self, device: str): + """Test _validate_cfg with limits entirely in the negative range.""" + num_instances = 2 + num_joints = 2 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Set limits entirely in negative range + joint_limit_lower = torch.full((num_instances, num_joints), -5.0, device=device) + joint_limit_upper = torch.full((num_instances, num_joints), -2.0, device=device) + mock_view.set_mock_data( + joint_limit_lower=wp.from_torch(joint_limit_lower, dtype=wp.float32), + joint_limit_upper=wp.from_torch(joint_limit_upper, dtype=wp.float32), + ) + + # Set position at zero (outside negative-only limits) + default_joint_pos = torch.zeros((num_instances, num_joints), device=device) + articulation.data._default_joint_pos = wp.from_torch(default_joint_pos, dtype=wp.float32) + + # Should raise ValueError + with pytest.raises(ValueError) as exc_info: + articulation._validate_cfg() + # Both joints should be reported as violated + assert "joint_0" in str(exc_info.value) + assert "joint_1" in str(exc_info.value) + + +# TODO: Expand these tests when tendons are available in Newton. +# Currently, tendons are not implemented and _process_tendons only initializes empty lists. +# When tendon support is added, tests should verify: +# - Fixed tendon properties are correctly parsed and stored +# - Spatial tendon properties are correctly parsed and stored +# - Tendon limits and stiffness values are correctly set +class TestProcessTendons: + """Tests for _process_tendons method. + + Note: Tendons are not yet implemented in Newton. These tests verify the current + placeholder behavior. When tendons are implemented, these tests should be expanded. + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_tendons_initializes_empty_lists(self, device: str): + """Test that _process_tendons initializes empty tendon name lists.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Call _process_tendons + articulation._process_tendons() + + # Verify empty lists are created + assert articulation._fixed_tendon_names == [] + assert articulation._spatial_tendon_names == [] + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_tendons_returns_none(self, device: str): + """Test that _process_tendons returns None (no tendons implemented).""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Call _process_tendons and verify return value + result = articulation._process_tendons() + assert result is None + + +# TODO: Expand these tests when actuator mocking is more mature. +# Full actuator integration tests would require: +# - Mocking ActuatorBaseCfg and ActuatorBase classes +# - Testing implicit vs explicit actuator behavior +# - Testing stiffness/damping propagation +# Currently, we test the initialization behavior without actuators configured. +class TestProcessActuatorsCfg: + """Tests for _process_actuators_cfg method. + + Note: These tests focus on the initialization behavior when no actuators are configured. + Full actuator integration tests require additional mocking infrastructure. + """ + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_initializes_empty_dict(self, device: str): + """Test that _process_actuators_cfg initializes actuators as empty dict when none configured.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # Ensure no actuators are configured + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify actuators dict is empty + assert articulation.actuators == {} + assert isinstance(articulation.actuators, dict) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_sets_implicit_flag_false(self, device: str): + """Test that _process_actuators_cfg sets _has_implicit_actuators to False initially.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify flag is set to False + assert articulation._has_implicit_actuators is False + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_sets_joint_limit_gains(self, device: str): + """Test that _process_actuators_cfg sets joint_limit_ke and joint_limit_kd.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + articulation.cfg.actuators = {} + + # Call _process_actuators_cfg + articulation._process_actuators_cfg() + + # Verify joint limit gains are set + joint_limit_ke = wp.to_torch(mock_view.get_attribute("joint_limit_ke", None)) + joint_limit_kd = wp.to_torch(mock_view.get_attribute("joint_limit_kd", None)) + + expected_ke = torch.full((num_instances, num_joints), 2500.0, device=device) + expected_kd = torch.full((num_instances, num_joints), 100.0, device=device) + + torch.testing.assert_close(joint_limit_ke, expected_ke, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(joint_limit_kd, expected_kd, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_warns_unactuated_joints(self, device: str): + """Test that _process_actuators_cfg warns when not all joints have actuators.""" + num_instances = 2 + num_joints = 4 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + device=device, + ) + + # No actuators configured but we have joints + articulation.cfg.actuators = {} + + # Should warn about unactuated joints + with pytest.warns(UserWarning, match="Not all actuators are configured"): + articulation._process_actuators_cfg() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_process_actuators_cfg_no_warning_zero_joints(self, device: str): + """Test that _process_actuators_cfg does not warn when there are no joints.""" + num_instances = 2 + num_joints = 0 + num_bodies = 1 + articulation, mock_view, _ = create_test_articulation( + num_instances=num_instances, + num_joints=num_joints, + num_bodies=num_bodies, + device=device, + ) + + articulation.cfg.actuators = {} + + # Should not warn when there are no joints to actuate + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error") + # This should not raise a warning + articulation._process_actuators_cfg() + + +## +# Main +## + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object_data.py b/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object_data.py new file mode 100644 index 00000000000..ba6bc7f5641 --- /dev/null +++ b/source/isaaclab_newton/test/assets/rigid_object/test_rigid_object_data.py @@ -0,0 +1,3292 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for ArticulationData class comparing Newton implementation against PhysX reference.""" + +from __future__ import annotations + +import torch +from unittest.mock import MagicMock, patch + +import pytest +import warp as wp +from isaaclab_newton.assets.articulation.articulation_data import ArticulationData + +# TODO: Remove this import +from isaaclab.utils import math as math_utils + +# Import mock classes from shared module +from .mock_interface import MockNewtonArticulationView, MockNewtonModel + +# Initialize Warp +wp.init() + + +## +# Test Fixtures +## + + +@pytest.fixture +def mock_newton_manager(): + """Create mock NewtonManager with necessary methods.""" + mock_model = MockNewtonModel() + mock_state = MagicMock() + mock_control = MagicMock() + + # Patch where NewtonManager is used (in the articulation_data module) + with patch("isaaclab_newton.assets.articulation.articulation_data.NewtonManager") as MockManager: + MockManager.get_model.return_value = mock_model + MockManager.get_state_0.return_value = mock_state + MockManager.get_control.return_value = mock_control + MockManager.get_dt.return_value = 0.01 + yield MockManager + + +## +# Test Cases -- Defaults. +## + + +class TestDefaults: + """Tests the following properties: + - default_root_pose + - default_root_vel + - default_joint_pos + - default_joint_vel + + Runs the following checks: + - Checks that by default, the properties are all zero. + - Checks that the properties are settable. + - Checks that once the articulation data is primed, the properties cannot be changed. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_zero_instantiated(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test zero instantiated articulation data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types are correct + assert articulation_data.default_root_pose.dtype is wp.transformf + assert articulation_data.default_root_vel.dtype is wp.spatial_vectorf + assert articulation_data.default_joint_pos.dtype is wp.float32 + assert articulation_data.default_joint_vel.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.default_root_pose.shape == (num_instances,) + assert articulation_data.default_root_vel.shape == (num_instances,) + assert articulation_data.default_joint_pos.shape == (num_instances, num_dofs) + assert articulation_data.default_joint_vel.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.default_root_pose) == torch.zeros(num_instances, 7, device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_root_vel) == torch.zeros(num_instances, 6, device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_pos) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_vel) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_settable(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the articulation data is settable.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Set the default values + articulation_data.default_root_pose = wp.ones(num_instances, dtype=wp.transformf, device=device) + articulation_data.default_root_vel = wp.ones(num_instances, dtype=wp.spatial_vectorf, device=device) + articulation_data.default_joint_pos = wp.ones((num_instances, num_dofs), dtype=wp.float32, device=device) + articulation_data.default_joint_vel = wp.ones((num_instances, num_dofs), dtype=wp.float32, device=device) + # Check the types are correct + assert articulation_data.default_root_pose.dtype is wp.transformf + assert articulation_data.default_root_vel.dtype is wp.spatial_vectorf + assert articulation_data.default_joint_pos.dtype is wp.float32 + assert articulation_data.default_joint_vel.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.default_root_pose.shape == (num_instances,) + assert articulation_data.default_root_vel.shape == (num_instances,) + assert articulation_data.default_joint_pos.shape == (num_instances, num_dofs) + assert articulation_data.default_joint_vel.shape == (num_instances, num_dofs) + # Check the values are set + assert torch.all( + wp.to_torch(articulation_data.default_root_pose) == torch.ones(num_instances, 7, device=device) + ) + assert torch.all(wp.to_torch(articulation_data.default_root_vel) == torch.ones(num_instances, 6, device=device)) + assert torch.all( + wp.to_torch(articulation_data.default_joint_pos) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.default_joint_vel) == torch.ones((num_instances, num_dofs), device=device) + ) + # Prime the articulation data + articulation_data.is_primed = True + # Check that the values cannot be changed + with pytest.raises(RuntimeError): + articulation_data.default_root_pose = wp.zeros(num_instances, dtype=wp.transformf, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_root_vel = wp.zeros(num_instances, dtype=wp.spatial_vectorf, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_joint_pos = wp.zeros((num_instances, num_dofs), dtype=wp.float32, device=device) + with pytest.raises(RuntimeError): + articulation_data.default_joint_vel = wp.zeros((num_instances, num_dofs), dtype=wp.float32, device=device) + + +## +# Test Cases -- Joint Commands (Set into the simulation). +## + + +class TestJointCommandsSetIntoSimulation: + """Tests the following properties: + - joint_pos_target + - joint_vel_target + - joint_effort_target + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint commands are initialized to zero.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types is correct + assert articulation_data.joint_pos_target.dtype is wp.float32 + assert articulation_data.joint_vel_target.dtype is wp.float32 + assert articulation_data.joint_effort.dtype is wp.float32 + # Check the shape is correct + assert articulation_data.joint_pos_target.shape == (num_instances, num_dofs) + assert articulation_data.joint_vel_target.shape == (num_instances, num_dofs) + assert articulation_data.joint_effort.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.joint_pos_target) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_target) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint commands return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Get the pointers + joint_pos_target = articulation_data.joint_pos_target + joint_vel_target = articulation_data.joint_vel_target + joint_effort = articulation_data.joint_effort + # Check that they are zeros + assert torch.all(wp.to_torch(joint_pos_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort) == torch.zeros((num_instances, num_dofs), device=device)) + # Assign a different value to the internal data + articulation_data.joint_pos_target.fill_(1.0) + articulation_data.joint_vel_target.fill_(1.0) + articulation_data.joint_effort.fill_(1.0) + # Check that the joint commands return the new value + assert torch.all(wp.to_torch(joint_pos_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort) == torch.ones((num_instances, num_dofs), device=device)) + # Assign a different value to the pointers + joint_pos_target.fill_(2.0) + joint_vel_target.fill_(2.0) + joint_effort.fill_(2.0) + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_pos_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + + +## +# Test Cases -- Joint Commands (Explicit actuators). +## + + +class TestJointCommandsExplicitActuators: + """Tests the following properties: + - computed_effort + - applied_effort + - actuator_stiffness + - actuator_damping + - actuator_position_target + - actuator_velocity_target + - actuator_effort_target + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the explicit actuator properties are initialized to zero.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Check the types are correct + assert articulation_data.computed_effort.dtype is wp.float32 + assert articulation_data.applied_effort.dtype is wp.float32 + assert articulation_data.actuator_stiffness.dtype is wp.float32 + assert articulation_data.actuator_damping.dtype is wp.float32 + assert articulation_data.actuator_position_target.dtype is wp.float32 + assert articulation_data.actuator_velocity_target.dtype is wp.float32 + assert articulation_data.actuator_effort_target.dtype is wp.float32 + # Check the shapes are correct + assert articulation_data.computed_effort.shape == (num_instances, num_dofs) + assert articulation_data.applied_effort.shape == (num_instances, num_dofs) + assert articulation_data.actuator_stiffness.shape == (num_instances, num_dofs) + assert articulation_data.actuator_damping.shape == (num_instances, num_dofs) + assert articulation_data.actuator_position_target.shape == (num_instances, num_dofs) + assert articulation_data.actuator_velocity_target.shape == (num_instances, num_dofs) + assert articulation_data.actuator_effort_target.shape == (num_instances, num_dofs) + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.computed_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.applied_effort) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_stiffness) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_damping) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_position_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_velocity_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_effort_target) + == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the explicit actuator properties return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + # Get the pointers + computed_effort = articulation_data.computed_effort + applied_effort = articulation_data.applied_effort + actuator_stiffness = articulation_data.actuator_stiffness + actuator_damping = articulation_data.actuator_damping + actuator_position_target = articulation_data.actuator_position_target + actuator_velocity_target = articulation_data.actuator_velocity_target + actuator_effort_target = articulation_data.actuator_effort_target + # Check that they are zeros + assert torch.all(wp.to_torch(computed_effort) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(applied_effort) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_stiffness) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_damping) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_position_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_velocity_target) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_effort_target) == torch.zeros((num_instances, num_dofs), device=device)) + # Assign a different value to the internal data + articulation_data.computed_effort.fill_(1.0) + articulation_data.applied_effort.fill_(1.0) + articulation_data.actuator_stiffness.fill_(1.0) + articulation_data.actuator_damping.fill_(1.0) + articulation_data.actuator_position_target.fill_(1.0) + articulation_data.actuator_velocity_target.fill_(1.0) + articulation_data.actuator_effort_target.fill_(1.0) + # Check that the properties return the new value + assert torch.all(wp.to_torch(computed_effort) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(applied_effort) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_stiffness) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_damping) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_position_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_velocity_target) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(actuator_effort_target) == torch.ones((num_instances, num_dofs), device=device)) + # Assign a different value to the pointers + computed_effort.fill_(2.0) + applied_effort.fill_(2.0) + actuator_stiffness.fill_(2.0) + actuator_damping.fill_(2.0) + actuator_position_target.fill_(2.0) + actuator_velocity_target.fill_(2.0) + actuator_effort_target.fill_(2.0) + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.computed_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.applied_effort) == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_stiffness) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_damping) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_position_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_velocity_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.actuator_effort_target) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + + +## +# Test Cases -- Joint Properties (Set into Simulation). +## + + +class TestJointPropertiesSetIntoSimulation: + """Tests the following properties: + - joint_stiffness + - joint_damping + - joint_armature + - joint_friction_coeff + - joint_pos_limits_lower + - joint_pos_limits_upper + - joint_pos_limits (read-only, computed from lower and upper) + - joint_vel_limits + - joint_effort_limits + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + + .. note:: joint_pos_limits is read-only and does not change the joint position limits. + """ + + def _setup_method( + self, num_instances: int, num_dofs: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + # return the mock view, so that it doesn't get garbage collected + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_to_zero(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint properties are initialized to zero (or ones for limits).""" + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Check the types are correct + assert articulation_data.joint_stiffness.dtype is wp.float32 + assert articulation_data.joint_damping.dtype is wp.float32 + assert articulation_data.joint_armature.dtype is wp.float32 + assert articulation_data.joint_friction_coeff.dtype is wp.float32 + assert articulation_data.joint_pos_limits_lower.dtype is wp.float32 + assert articulation_data.joint_pos_limits_upper.dtype is wp.float32 + assert articulation_data.joint_pos_limits.dtype is wp.vec2f + assert articulation_data.joint_vel_limits.dtype is wp.float32 + assert articulation_data.joint_effort_limits.dtype is wp.float32 + + # Check the shapes are correct + assert articulation_data.joint_stiffness.shape == (num_instances, num_dofs) + assert articulation_data.joint_damping.shape == (num_instances, num_dofs) + assert articulation_data.joint_armature.shape == (num_instances, num_dofs) + assert articulation_data.joint_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits_lower.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits_upper.shape == (num_instances, num_dofs) + assert articulation_data.joint_pos_limits.shape == (num_instances, num_dofs) + assert articulation_data.joint_vel_limits.shape == (num_instances, num_dofs) + assert articulation_data.joint_effort_limits.shape == (num_instances, num_dofs) + + # Check the values are zero + assert torch.all( + wp.to_torch(articulation_data.joint_stiffness) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_damping) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_armature) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.zeros((num_instances, num_dofs), device=device) + ) + # joint_pos_limits should be (0, 0) for each joint since both lower and upper are 0 + joint_pos_limits = wp.to_torch(articulation_data.joint_pos_limits) + assert torch.all(joint_pos_limits == torch.zeros((num_instances, num_dofs, 2), device=device)) + # vel_limits and effort_limits are initialized to zeros in the mock + assert torch.all( + wp.to_torch(articulation_data.joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort_limits) == torch.zeros((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the joint properties return a reference to the internal data. + + Note: joint_pos_limits is read-only and always returns a new computed array. + """ + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Get the pointers + joint_stiffness = articulation_data.joint_stiffness + joint_damping = articulation_data.joint_damping + joint_armature = articulation_data.joint_armature + joint_friction_coeff = articulation_data.joint_friction_coeff + joint_pos_limits_lower = articulation_data.joint_pos_limits_lower + joint_pos_limits_upper = articulation_data.joint_pos_limits_upper + joint_vel_limits = articulation_data.joint_vel_limits + joint_effort_limits = articulation_data.joint_effort_limits + + # Check that they have initial values (zeros or ones based on mock) + assert torch.all(wp.to_torch(joint_stiffness) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_damping) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_armature) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_pos_limits_lower) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_pos_limits_upper) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_effort_limits) == torch.zeros((num_instances, num_dofs), device=device)) + + # Assign a different value to the internal data + articulation_data.joint_stiffness.fill_(1.0) + articulation_data.joint_damping.fill_(1.0) + articulation_data.joint_armature.fill_(1.0) + articulation_data.joint_friction_coeff.fill_(1.0) + articulation_data.joint_pos_limits_lower.fill_(-1.0) + articulation_data.joint_pos_limits_upper.fill_(1.0) + articulation_data.joint_vel_limits.fill_(2.0) + articulation_data.joint_effort_limits.fill_(2.0) + + # Check that the properties return the new value (reference behavior) + assert torch.all(wp.to_torch(joint_stiffness) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_damping) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_armature) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_friction_coeff) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all( + wp.to_torch(joint_pos_limits_lower) == torch.ones((num_instances, num_dofs), device=device) * -1.0 + ) + assert torch.all(wp.to_torch(joint_pos_limits_upper) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(joint_vel_limits) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + assert torch.all(wp.to_torch(joint_effort_limits) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + + # Check that joint_pos_limits is computed correctly from lower and upper + joint_pos_limits = wp.to_torch(articulation_data.joint_pos_limits) + expected_limits = torch.stack( + [ + torch.ones((num_instances, num_dofs), device=device) * -1.0, + torch.ones((num_instances, num_dofs), device=device), + ], + dim=-1, + ) + assert torch.all(joint_pos_limits == expected_limits) + + # Assign a different value to the pointers + joint_stiffness.fill_(3.0) + joint_damping.fill_(3.0) + joint_armature.fill_(3.0) + joint_friction_coeff.fill_(3.0) + joint_pos_limits_lower.fill_(-2.0) + joint_pos_limits_upper.fill_(2.0) + joint_vel_limits.fill_(4.0) + joint_effort_limits.fill_(4.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_stiffness) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_damping) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_armature) == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.ones((num_instances, num_dofs), device=device) * -2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.ones((num_instances, num_dofs), device=device) * 2.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_vel_limits) + == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_effort_limits) + == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + + # Verify joint_pos_limits reflects the updated lower and upper values + joint_pos_limits_updated = wp.to_torch(articulation_data.joint_pos_limits) + expected_limits_updated = torch.stack( + [ + torch.ones((num_instances, num_dofs), device=device) * -2.0, + torch.ones((num_instances, num_dofs), device=device) * 2.0, + ], + dim=-1, + ) + assert torch.all(joint_pos_limits_updated == expected_limits_updated) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_joint_pos_limits_is_read_only(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that joint_pos_limits returns a new array each time (not a reference). + + Unlike other joint properties, joint_pos_limits is computed on-the-fly from + joint_pos_limits_lower and joint_pos_limits_upper. Modifying the returned array + should not affect the underlying data. + """ + # Setup the articulation data + articulation_data, _ = self._setup_method(num_instances, num_dofs, device) + + # Get joint_pos_limits twice + limits1 = articulation_data.joint_pos_limits + limits2 = articulation_data.joint_pos_limits + + # They should be separate arrays (not the same reference) + # Modifying one should not affect the other + limits1.fill_(2.0) + + # limits2 should be changed to 2.0 + assert torch.all(wp.to_torch(limits2) == torch.ones((num_instances, num_dofs, 2), device=device) * 2.0) + + # The underlying lower and upper should be unchanged + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_lower) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_pos_limits_upper) + == torch.zeros((num_instances, num_dofs), device=device) + ) + + +## +# Test Cases -- Joint Properties (Custom). +## + + +class TestJointPropertiesCustom: + """Tests the following properties: + - joint_dynamic_friction_coeff + - joint_viscous_friction_coeff + - soft_joint_pos_limits + - soft_joint_vel_limits + - gear_ratio + + Runs the following checks: + - Checks that their types and shapes are correct. + - Checks that the returned values are pointers to the internal data. + + .. note:: gear_ratio is initialized to ones (not zeros). + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_initialized_correctly(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the custom joint properties are initialized correctly.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + + # Check the types are correct + assert articulation_data.joint_dynamic_friction_coeff.dtype is wp.float32 + assert articulation_data.joint_viscous_friction_coeff.dtype is wp.float32 + assert articulation_data.soft_joint_pos_limits.dtype is wp.vec2f + assert articulation_data.soft_joint_vel_limits.dtype is wp.float32 + assert articulation_data.gear_ratio.dtype is wp.float32 + + # Check the shapes are correct + assert articulation_data.joint_dynamic_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.joint_viscous_friction_coeff.shape == (num_instances, num_dofs) + assert articulation_data.soft_joint_pos_limits.shape == (num_instances, num_dofs) + assert articulation_data.soft_joint_vel_limits.shape == (num_instances, num_dofs) + assert articulation_data.gear_ratio.shape == (num_instances, num_dofs) + + # Check the values are initialized correctly + # Most are zeros, but gear_ratio is initialized to ones + assert torch.all( + wp.to_torch(articulation_data.joint_dynamic_friction_coeff) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.joint_viscous_friction_coeff) + == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_pos_limits) + == torch.zeros((num_instances, num_dofs, 2), device=device) + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_vel_limits) + == torch.zeros((num_instances, num_dofs), device=device) + ) + # gear_ratio is initialized to ones + assert torch.all( + wp.to_torch(articulation_data.gear_ratio) == torch.ones((num_instances, num_dofs), device=device) + ) + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_dofs", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_returns_reference(self, mock_newton_manager, num_instances: int, num_dofs: int, device: str): + """Test that the custom joint properties return a reference to the internal data.""" + # Setup the articulation data + articulation_data = self._setup_method(num_instances, num_dofs, device) + + # Get the pointers + joint_dynamic_friction_coeff = articulation_data.joint_dynamic_friction_coeff + joint_viscous_friction_coeff = articulation_data.joint_viscous_friction_coeff + soft_joint_pos_limits = articulation_data.soft_joint_pos_limits + soft_joint_vel_limits = articulation_data.soft_joint_vel_limits + gear_ratio = articulation_data.gear_ratio + + # Check that they have initial values + assert torch.all( + wp.to_torch(joint_dynamic_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(joint_viscous_friction_coeff) == torch.zeros((num_instances, num_dofs), device=device) + ) + assert torch.all(wp.to_torch(soft_joint_pos_limits) == torch.zeros((num_instances, num_dofs, 2), device=device)) + assert torch.all(wp.to_torch(soft_joint_vel_limits) == torch.zeros((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(gear_ratio) == torch.ones((num_instances, num_dofs), device=device)) + + # Assign a different value to the internal data + articulation_data.joint_dynamic_friction_coeff.fill_(1.0) + articulation_data.joint_viscous_friction_coeff.fill_(1.0) + articulation_data.soft_joint_pos_limits.fill_(1.0) + articulation_data.soft_joint_vel_limits.fill_(1.0) + articulation_data.gear_ratio.fill_(2.0) + + # Check that the properties return the new value (reference behavior) + assert torch.all( + wp.to_torch(joint_dynamic_friction_coeff) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all( + wp.to_torch(joint_viscous_friction_coeff) == torch.ones((num_instances, num_dofs), device=device) + ) + assert torch.all(wp.to_torch(soft_joint_pos_limits) == torch.ones((num_instances, num_dofs, 2), device=device)) + assert torch.all(wp.to_torch(soft_joint_vel_limits) == torch.ones((num_instances, num_dofs), device=device)) + assert torch.all(wp.to_torch(gear_ratio) == torch.ones((num_instances, num_dofs), device=device) * 2.0) + + # Assign a different value to the pointers + joint_dynamic_friction_coeff.fill_(3.0) + joint_viscous_friction_coeff.fill_(3.0) + soft_joint_pos_limits.fill_(3.0) + soft_joint_vel_limits.fill_(3.0) + gear_ratio.fill_(4.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.joint_dynamic_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.joint_viscous_friction_coeff) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_pos_limits) + == torch.ones((num_instances, num_dofs, 2), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.soft_joint_vel_limits) + == torch.ones((num_instances, num_dofs), device=device) * 3.0 + ) + assert torch.all( + wp.to_torch(articulation_data.gear_ratio) == torch.ones((num_instances, num_dofs), device=device) * 4.0 + ) + + +## +# Test Cases -- Fixed Tendon Properties. +## + + +# TODO: Update these tests when fixed tendon support is added to Newton. +class TestFixedTendonProperties: + """Tests the following properties: + - fixed_tendon_stiffness + - fixed_tendon_damping + - fixed_tendon_limit_stiffness + - fixed_tendon_rest_length + - fixed_tendon_offset + - fixed_tendon_pos_limits + + Currently, all these properties raise NotImplementedError as fixed tendons + are not supported in Newton. + + Runs the following checks: + - Checks that all properties raise NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_fixed_tendon_properties_not_implemented(self, mock_newton_manager, device: str): + """Test that all fixed tendon properties raise NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_damping + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_limit_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_rest_length + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_offset + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_pos_limits + + +## +# Test Cases -- Spatial Tendon Properties. +## + + +# TODO: Update these tests when spatial tendon support is added to Newton. +class TestSpatialTendonProperties: + """Tests the following properties: + - spatial_tendon_stiffness + - spatial_tendon_damping + - spatial_tendon_limit_stiffness + - spatial_tendon_offset + + Currently, all these properties raise NotImplementedError as spatial tendons + are not supported in Newton. + + Runs the following checks: + - Checks that all properties raise NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_dofs: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, 1, num_dofs, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_spatial_tendon_properties_not_implemented(self, mock_newton_manager, device: str): + """Test that all spatial tendon properties raise NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_damping + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_limit_stiffness + with pytest.raises(NotImplementedError): + _ = articulation_data.spatial_tendon_offset + + +## +# Test Cases -- Root state properties. +## + + +class TestRootLinkPoseW: + """Tests the root link pose property + + This value is read from the simulation. There is no math to check for. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_link_pose_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root link pose property returns a pointer to the internal data.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_link_pose_w.shape == (num_instances,) + assert articulation_data.root_link_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.zeros((1, 7), device=device)) + + # Get the property + root_link_pose_w = articulation_data.root_link_pose_w + + # Assign a different value to the internal data + articulation_data.root_link_pose_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.ones((1, 7), device=device)) + + # Assign a different value to the pointers + root_link_pose_w.fill_(2.0) + + # Check that the internal data has been updated + assert torch.all(wp.to_torch(articulation_data.root_link_pose_w) == torch.ones((1, 7), device=device) * 2.0) + + +class TestRootLinkVelW: + """Tests the root link velocity property + + This value is derived from the root center of mass velocity. To ensure that the value is correctly computed, + we will compare the calculated value to the one currently calculated in the version 2.3.1 of IsaacLab. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + - Checks that the timestamp is updated correctly. + - Checks that the data is invalidated when the timestamp is updated. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root link velocity property is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_link_vel_w.shape == (num_instances,) + assert articulation_data.root_link_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_link_vel_w) == torch.zeros((num_instances, 6), device=device) + ) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random com velocity and body com position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Use the original IsaacLab code to compute the root link velocities + vel = com_vel.clone() + # TODO: Move the function from math_utils to a test utils file. Decoupling it from changes in math_utils. + vel[:, :3] += torch.linalg.cross( + vel[:, 3:], math_utils.quat_apply(root_link_pose[:, 3:], -body_com_pos[:, 0]), dim=-1 + ) + + # Compare the computed value to the one from the articulation data + assert torch.allclose(wp.to_torch(articulation_data.root_link_vel_w), vel, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_update_timestamp(self, mock_newton_manager, device: str): + """Test that the timestamp is updated correctly.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check that the timestamp is initialized to -1.0 + assert articulation_data._root_link_vel_w.timestamp == -1.0 + + # Check that the data class timestamp is initialized to 0.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property + value = wp.to_torch(articulation_data.root_link_vel_w).clone() + + # Check that the timestamp is updated. The timestamp should be the same as the data class timestamp. + assert articulation_data._root_link_vel_w.timestamp == articulation_data._sim_timestamp + + # Update the root_com_vel_w + mock_view.set_mock_data( + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Check that the property value was not updated + assert torch.all(wp.to_torch(articulation_data.root_link_vel_w) == value) + + # Update the data class timestamp + articulation_data._sim_timestamp = 1.0 + + # Check that the property timestamp was not updated + assert articulation_data._root_link_vel_w.timestamp != articulation_data._sim_timestamp + + # Check that the property value was updated + assert torch.all(wp.to_torch(articulation_data.root_link_vel_w) != value) + + +class TestRootComPoseW: + """Tests the root center of mass pose property + + This value is derived from the root link pose and the body com position. To ensure that the value is correctly computed, + we will compare the calculated value to the one currently calculated in the version 2.3.1 of IsaacLab. + + Runs the following checks: + - Checks that the returned value is a pointer to the internal data. + - Checks that the returned value is correct. + - Checks that the timestamp is updated correctly. + - Checks that the data is invalidated when the timestamp is updated. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_com_pose_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root center of mass pose property returns a pointer to the internal data.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_com_pose_w.shape == (num_instances,) + assert articulation_data.root_com_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_com_pose_w) == torch.zeros((num_instances, 7), device=device) + ) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root link pose and body com position + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Use the original IsaacLab code to compute the root center of mass pose + root_link_pos_w = root_link_pose[:, :3] + root_link_quat_w = root_link_pose[:, 3:] + body_com_pos_b = body_com_pos.clone() + body_com_quat_b = torch.zeros((num_instances, 1, 4), device=device) + body_com_quat_b[:, :, 3] = 1.0 + # --- IL 2.3.1 code --- + pos, quat = math_utils.combine_frame_transforms( + root_link_pos_w, root_link_quat_w, body_com_pos_b[:, 0], body_com_quat_b[:, 0] + ) + # --- + root_com_pose = torch.cat((pos, quat), dim=-1) + + # Compare the computed value to the one from the articulation data + assert torch.allclose(wp.to_torch(articulation_data.root_com_pose_w), root_com_pose, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_update_timestamp(self, mock_newton_manager, device: str): + """Test that the timestamp is updated correctly.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check that the timestamp is initialized to -1.0 + assert articulation_data._root_com_pose_w.timestamp == -1.0 + + # Check that the data class timestamp is initialized to 0.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property + value = wp.to_torch(articulation_data.root_com_pose_w).clone() + + # Check that the timestamp is updated. The timestamp should be the same as the data class timestamp. + assert articulation_data._root_com_pose_w.timestamp == articulation_data._sim_timestamp + + # Update the root_com_vel_w + mock_view.set_mock_data( + root_transforms=wp.from_torch(torch.rand((1, 7), device=device), dtype=wp.transformf), + ) + + # Check that the property value was not updated + assert torch.all(wp.to_torch(articulation_data.root_com_pose_w) == value) + + # Update the data class timestamp + articulation_data._sim_timestamp = 1.0 + + # Check that the property timestamp was not updated + assert articulation_data._root_com_pose_w.timestamp != articulation_data._sim_timestamp + + # Check that the property value was updated + assert torch.all(wp.to_torch(articulation_data.root_com_pose_w) != value) + + +class TestRootComVelW: + """Tests the root center of mass velocity property + + This value is read from the simulation. There is no math to check for. + + Checks that the returned value is a pointer to the internal data. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_root_com_vel_w(self, mock_newton_manager, num_instances: int, device: str): + """Test that the root center of mass velocity property returns a pointer to the internal data.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.root_com_vel_w.shape == (num_instances,) + assert articulation_data.root_com_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + assert torch.all( + wp.to_torch(articulation_data.root_com_vel_w) == torch.zeros((num_instances, 6), device=device) + ) + + # Get the property + root_com_vel_w = articulation_data.root_com_vel_w + + # Assign a different value to the internal data + articulation_data.root_com_vel_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all(wp.to_torch(articulation_data.root_com_vel_w) == torch.ones((num_instances, 6), device=device)) + + # Assign a different value to the pointers + root_com_vel_w.fill_(2.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.root_com_vel_w) == torch.ones((num_instances, 6), device=device) * 2.0 + ) + + +class TestRootState: + """Tests the root state properties + + Test the root state properties are correctly updated from the pose and velocity properties. + Tests the following properties: + - root_state_w + - root_link_state_w + - root_com_state_w + + For each property, we run the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly assembled from pose and velocity. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_root_state_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all root state properties correctly combine pose and velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Generate random mock data + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root link pose + root_link_pose = torch.zeros((num_instances, 7), device=device) + root_link_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_link_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_link_pose[:, 3:] = torch.nn.functional.normalize(root_link_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random velocities and com position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_link_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_state_w --- + # Combines root_link_pose_w with root_com_vel_w + root_state = wp.to_torch(articulation_data.root_state_w) + expected_root_state = torch.cat([root_link_pose, com_vel], dim=-1) + + assert root_state.shape == (num_instances, 13) + assert torch.allclose(root_state, expected_root_state, atol=1e-6, rtol=1e-6) + + # --- Test root_link_state_w --- + # Combines root_link_pose_w with root_link_vel_w + root_link_state = wp.to_torch(articulation_data.root_link_state_w) + + # Compute expected root_link_vel from com_vel (same as TestRootLinkVelW) + root_link_vel = com_vel.clone() + root_link_vel[:, :3] += torch.linalg.cross( + root_link_vel[:, 3:], math_utils.quat_apply(root_link_pose[:, 3:], -body_com_pos[:, 0]), dim=-1 + ) + expected_root_link_state = torch.cat([root_link_pose, root_link_vel], dim=-1) + + assert root_link_state.shape == (num_instances, 13) + assert torch.allclose(root_link_state, expected_root_link_state, atol=1e-6, rtol=1e-6) + + # --- Test root_com_state_w --- + # Combines root_com_pose_w with root_com_vel_w + root_com_state = wp.to_torch(articulation_data.root_com_state_w) + + # Compute expected root_com_pose from root_link_pose and body_com_pos (same as TestRootComPoseW) + body_com_quat_b = torch.zeros((num_instances, 4), device=device) + body_com_quat_b[:, 3] = 1.0 + root_com_pos, root_com_quat = math_utils.combine_frame_transforms( + root_link_pose[:, :3], root_link_pose[:, 3:], body_com_pos[:, 0], body_com_quat_b + ) + expected_root_com_state = torch.cat([root_com_pos, root_com_quat, com_vel], dim=-1) + + assert root_com_state.shape == (num_instances, 13) + assert torch.allclose(root_com_state, expected_root_com_state, atol=1e-6, rtol=1e-6) + + +## +# Test Cases -- Body state properties. +## + + +class TestBodyMassInertia: + """Tests the body mass and inertia properties. + + These values are read directly from the simulation bindings. + + Tests the following properties: + - body_mass + - body_inertia + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_mass_and_inertia(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_mass and body_inertia have correct types, shapes, and reference behavior.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # --- Test body_mass --- + # Check the type and shape + assert articulation_data.body_mass.shape == (num_instances, num_bodies) + assert articulation_data.body_mass.dtype == wp.float32 + + # Mock data initializes body_mass to ones + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.zeros((num_instances, num_bodies), device=device) + ) + + # Get the property reference + body_mass_ref = articulation_data.body_mass + + # Assign a different value to the internal data via property + articulation_data.body_mass.fill_(2.0) + + # Check that the property returns the new value (reference behavior) + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.ones((num_instances, num_bodies), device=device) * 2.0 + ) + + # Assign a different value via reference + body_mass_ref.fill_(3.0) + + # Check that the internal data has been updated + assert torch.all( + wp.to_torch(articulation_data.body_mass) == torch.ones((num_instances, num_bodies), device=device) * 3.0 + ) + + # --- Test body_inertia --- + # Check the type and shape + assert articulation_data.body_inertia.shape == (num_instances, num_bodies) + assert articulation_data.body_inertia.dtype == wp.mat33f + + # Mock data initializes body_inertia to zeros + expected_inertia = torch.zeros((num_instances, num_bodies, 3, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia) + + # Get the property reference + body_inertia_ref = articulation_data.body_inertia + + # Assign a different value to the internal data via property + articulation_data.body_inertia.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_inertia_ones = torch.ones((num_instances, num_bodies, 3, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia_ones) + + # Assign a different value via reference + body_inertia_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_inertia_twos = torch.ones((num_instances, num_bodies, 3, 3), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_inertia) == expected_inertia_twos) + + +class TestBodyLinkPoseW: + """Tests the body link pose property. + + This value is read directly from the simulation bindings. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_link_pose_w(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_link_pose_w has correct type, shape, and reference behavior.""" + articulation_data, _ = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_link_pose_w.shape == (num_instances, num_bodies) + assert articulation_data.body_link_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected) + + # Get the property reference + body_link_pose_ref = articulation_data.body_link_pose_w + + # Assign a different value via property + articulation_data.body_link_pose_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected_ones) + + # Assign a different value via reference + body_link_pose_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_bodies, 7), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_link_pose_w) == expected_twos) + + +class TestBodyLinkVelW: + """Tests the body link velocity property. + + This value is derived from body COM velocity. To ensure correctness, + we compare against the reference implementation. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_link_vel_w is correctly computed from COM velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_link_vel_w.shape == (num_instances, num_bodies) + assert articulation_data.body_link_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_link_vel_w) == expected) + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random COM velocity and body COM position + com_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + # Generate random link poses with normalized quaternions + link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + link_pose[..., 3:] = torch.nn.functional.normalize(link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(link_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected link velocity using IsaacLab reference implementation + # vel[:, :3] += cross(vel[:, 3:], quat_apply(quat, -body_com_pos)) + expected_vel = com_vel.clone() + expected_vel[..., :3] += torch.linalg.cross( + expected_vel[..., 3:], + math_utils.quat_apply(link_pose[..., 3:], -body_com_pos), + dim=-1, + ) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_link_vel_w), expected_vel, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, 1, device) + + # Check initial timestamp + assert articulation_data._body_link_vel_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_link_vel_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_link_vel_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_velocities=wp.from_torch(torch.rand((1, 1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_link_vel_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_link_vel_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_link_vel_w) == value) + + +class TestBodyComPoseW: + """Tests the body center of mass pose property. + + This value is derived from body link pose and body COM position. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pose_w is correctly computed from link pose and COM position.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_pose_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pose_w.dtype == wp.transformf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pose_w) == expected) + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random link poses with normalized quaternions + link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + link_pose[..., 3:] = torch.nn.functional.normalize(link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random body COM position in body frame + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(link_pose, dtype=wp.transformf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected COM pose using IsaacLab reference implementation + # combine_frame_transforms(link_pos, link_quat, com_pos_b, identity_quat) + body_com_quat_b = torch.zeros((num_instances, num_bodies, 4), device=device) + body_com_quat_b[..., 3] = 1.0 # identity quaternion + + expected_pos, expected_quat = math_utils.combine_frame_transforms( + link_pose[..., :3], link_pose[..., 3:], body_com_pos, body_com_quat_b + ) + expected_pose = torch.cat([expected_pos, expected_quat], dim=-1) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_com_pose_w), expected_pose, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, 1, device) + + # Check initial timestamp + assert articulation_data._body_com_pose_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_com_pose_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_com_pose_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_transforms=wp.from_torch(torch.rand((1, 1, 7), device=device), dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_com_pose_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_com_pose_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_com_pose_w) == value) + + +class TestBodyComVelW: + """Tests the body center of mass velocity property. + + This value is read directly from the simulation bindings. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_vel_w(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_vel_w has correct type, shape, and reference behavior.""" + articulation_data, _ = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_vel_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_vel_w.dtype == wp.spatial_vectorf + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected) + + # Get the property reference + body_com_vel_ref = articulation_data.body_com_vel_w + + # Assign a different value via property + articulation_data.body_com_vel_w.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_bodies, 6), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected_ones) + + # Assign a different value via reference + body_com_vel_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_bodies, 6), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_com_vel_w) == expected_twos) + + +class TestBodyState: + """Tests the body state properties. + + Test the body state properties are correctly updated from the pose and velocity properties. + Tests the following properties: + - body_state_w + - body_link_state_w + - body_com_state_w + + For each property, we run the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly assembled from pose and velocity. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_body_state_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that all body state properties correctly combine pose and velocity.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Generate random mock data + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random body link pose with normalized quaternions + body_link_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_link_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_link_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_link_pose[..., 3:] = torch.nn.functional.normalize(body_link_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocities and COM position + com_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_link_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_state_w --- + # Combines body_link_pose_w with body_com_vel_w + body_state = wp.to_torch(articulation_data.body_state_w) + expected_body_state = torch.cat([body_link_pose, com_vel], dim=-1) + + assert body_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_state, expected_body_state, atol=1e-6, rtol=1e-6) + + # --- Test body_link_state_w --- + # Combines body_link_pose_w with body_link_vel_w + body_link_state = wp.to_torch(articulation_data.body_link_state_w) + + # Compute expected body_link_vel from com_vel (same as TestBodyLinkVelW) + body_link_vel = com_vel.clone() + body_link_vel[..., :3] += torch.linalg.cross( + body_link_vel[..., 3:], + math_utils.quat_apply(body_link_pose[..., 3:], -body_com_pos), + dim=-1, + ) + expected_body_link_state = torch.cat([body_link_pose, body_link_vel], dim=-1) + + assert body_link_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_link_state, expected_body_link_state, atol=1e-6, rtol=1e-6) + + # --- Test body_com_state_w --- + # Combines body_com_pose_w with body_com_vel_w + body_com_state = wp.to_torch(articulation_data.body_com_state_w) + + # Compute expected body_com_pose from body_link_pose and body_com_pos (same as TestBodyComPoseW) + body_com_quat_b = torch.zeros((num_instances, num_bodies, 4), device=device) + body_com_quat_b[..., 3] = 1.0 + body_com_pos_w, body_com_quat_w = math_utils.combine_frame_transforms( + body_link_pose[..., :3], body_link_pose[..., 3:], body_com_pos, body_com_quat_b + ) + expected_body_com_state = torch.cat([body_com_pos_w, body_com_quat_w, com_vel], dim=-1) + + assert body_com_state.shape == (num_instances, num_bodies, 13) + assert torch.allclose(body_com_state, expected_body_com_state, atol=1e-6, rtol=1e-6) + + +class TestBodyComAccW: + """Tests the body center of mass acceleration property. + + This value is derived from velocity finite differencing: (current_vel - previous_vel) / dt + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str, initial_vel: torch.Tensor | None = None + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + + # Set initial velocities (these become _previous_body_com_vel) + if initial_vel is not None: + mock_view.set_mock_data( + link_velocities=wp.from_torch(initial_vel, dtype=wp.spatial_vectorf), + ) + else: + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_acc_w is correctly computed from velocity finite differencing.""" + # Initial velocity (becomes previous_velocity) + previous_vel = torch.rand((num_instances, num_bodies, 6), device=device) + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device, previous_vel) + + # Check the type and shape + assert articulation_data.body_com_acc_w.shape == (num_instances, num_bodies) + assert articulation_data.body_com_acc_w.dtype == wp.spatial_vectorf + + # dt is mocked as 0.01 + dt = 0.01 + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + + # Generate new random velocity + current_vel = torch.rand((num_instances, num_bodies, 6), device=device) + mock_view.set_mock_data( + link_velocities=wp.from_torch(current_vel, dtype=wp.spatial_vectorf), + ) + + # Compute expected acceleration: (current - previous) / dt + expected_acc = (current_vel - previous_vel) / dt + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.body_com_acc_w), expected_acc, atol=1e-5, rtol=1e-5) + # Update previous velocity + previous_vel = current_vel.clone() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + initial_vel = torch.zeros((1, 1, 6), device=device) + articulation_data, mock_view = self._setup_method(1, 1, device, initial_vel) + + # Check initial timestamp + assert articulation_data._body_com_acc_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.body_com_acc_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._body_com_acc_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + link_velocities=wp.from_torch(torch.rand((1, 1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.body_com_acc_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._body_com_acc_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.body_com_acc_w) == value) + + +class TestBodyComPoseB: + """Tests the body center of mass pose in body frame property. + + This value is generated from COM position with identity quaternion. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value correctly combines position with identity quaternion. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_pose_b(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pose_b correctly generates pose from position with identity quaternion.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Check the type and shape + assert articulation_data.body_com_pose_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pose_b.dtype == wp.transformf + + # Mock data is initialized to zeros for COM position + # Expected pose: [0, 0, 0, 0, 0, 0, 1] (position zeros, identity quaternion) + expected = torch.zeros((num_instances, num_bodies, 7), device=device) + expected[..., 6] = 1.0 # w component of identity quaternion + assert torch.all(wp.to_torch(articulation_data.body_com_pose_b) == expected) + + # Update COM position and verify + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # Get the pose + pose = wp.to_torch(articulation_data.body_com_pose_b) + + # Expected: position from mock, identity quaternion + expected_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + expected_pose[..., :3] = com_pos + expected_pose[..., 6] = 1.0 # w component + + assert torch.allclose(pose, expected_pose, atol=1e-6, rtol=1e-6) + + +# TODO: Update this test when body_incoming_joint_wrench_b support is added to Newton. +class TestBodyIncomingJointWrenchB: + """Tests the body incoming joint wrench property. + + Currently, this property raises NotImplementedError as joint wrenches + are not supported in Newton. + + Runs the following checks: + - Checks that the property raises NotImplementedError. + """ + + def _setup_method(self, num_instances: int, num_bodies: int, device: str) -> ArticulationData: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_not_implemented(self, mock_newton_manager, device: str): + """Test that body_incoming_joint_wrench_b raises NotImplementedError.""" + articulation_data = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.body_incoming_joint_wrench_b + + +## +# Test Cases -- Joint state properties. +## + + +class TestJointPosVel: + """Tests the joint position and velocity properties. + + These values are read directly from the simulation bindings. + + Tests the following properties: + - joint_pos + - joint_vel + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is a reference to the internal data. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_joint_pos_and_vel(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that joint_pos and joint_vel have correct type, shape, and reference behavior.""" + articulation_data, mock_view = self._setup_method(num_instances, num_joints, device) + + # --- Test joint_pos --- + # Check the type and shape + assert articulation_data.joint_pos.shape == (num_instances, num_joints) + assert articulation_data.joint_pos.dtype == wp.float32 + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected) + + # Get the property reference + joint_pos_ref = articulation_data.joint_pos + + # Assign a different value via property + articulation_data.joint_pos.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected_ones) + + # Assign a different value via reference + joint_pos_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_joints), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.joint_pos) == expected_twos) + + # --- Test joint_vel --- + # Check the type and shape + assert articulation_data.joint_vel.shape == (num_instances, num_joints) + assert articulation_data.joint_vel.dtype == wp.float32 + + # Mock data is initialized to zeros + expected = torch.zeros((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected) + + # Get the property reference + joint_vel_ref = articulation_data.joint_vel + + # Assign a different value via property + articulation_data.joint_vel.fill_(1.0) + + # Check that the property returns the new value (reference behavior) + expected_ones = torch.ones((num_instances, num_joints), device=device) + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected_ones) + + # Assign a different value via reference + joint_vel_ref.fill_(2.0) + + # Check that the internal data has been updated + expected_twos = torch.ones((num_instances, num_joints), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.joint_vel) == expected_twos) + + +class TestJointAcc: + """Tests the joint acceleration property. + + This value is derived from velocity finite differencing: (current_vel - previous_vel) / dt + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str, initial_vel: torch.Tensor | None = None + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + + # Set initial velocities (these become _previous_joint_vel) + if initial_vel is not None: + mock_view.set_mock_data( + dof_velocities=wp.from_torch(initial_vel, dtype=wp.float32), + ) + else: + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that joint_acc is correctly computed from velocity finite differencing.""" + # Initial velocity (becomes previous_velocity) + previous_vel = torch.rand((num_instances, num_joints), device=device) + articulation_data, mock_view = self._setup_method(num_instances, num_joints, device, previous_vel) + + # Check the type and shape + assert articulation_data.joint_acc.shape == (num_instances, num_joints) + assert articulation_data.joint_acc.dtype == wp.float32 + + # dt is mocked as 0.01 + dt = 0.01 + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate new random velocity + current_vel = torch.rand((num_instances, num_joints), device=device) + mock_view.set_mock_data( + dof_velocities=wp.from_torch(current_vel, dtype=wp.float32), + ) + + # Compute expected acceleration: (current - previous) / dt + expected_acc = (current_vel - previous_vel) / dt + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.joint_acc), expected_acc, atol=1e-5, rtol=1e-5) + # Update previous velocity + previous_vel = current_vel.clone() + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + initial_vel = torch.zeros((1, 1), device=device) + articulation_data, mock_view = self._setup_method(1, 1, device, initial_vel) + + # Check initial timestamp + assert articulation_data._joint_acc.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.joint_acc).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._joint_acc.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + mock_view.set_mock_data( + dof_velocities=wp.from_torch(torch.rand((1, 1), device=device), dtype=wp.float32), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.joint_acc) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._joint_acc.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.joint_acc) == value) + + +## +# Test Cases -- Derived properties. +## + + +class TestProjectedGravityB: + """Tests the projected gravity in body frame property. + + This value is derived by projecting the gravity vector onto the body frame. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that projected_gravity_b is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.projected_gravity_b.shape == (num_instances,) + assert articulation_data.projected_gravity_b.dtype == wp.vec3f + + # Gravity direction (normalized) + gravity_dir = torch.tensor([0.0, 0.0, -1.0], device=device) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + ) + + # Compute expected projected gravity: quat_apply(quat, gravity_dir) + # This rotates gravity from world to body frame + expected = math_utils.quat_apply_inverse(root_pose[:, 3:], gravity_dir.expand(num_instances, 3)) + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.projected_gravity_b), expected, atol=1e-4, rtol=1e-4) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._projected_gravity_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.projected_gravity_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._projected_gravity_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.projected_gravity_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._projected_gravity_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.projected_gravity_b) == value) + + +class TestHeadingW: + """Tests the heading in world frame property. + + This value is derived by computing the yaw angle from the forward direction. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that the timestamp is updated correctly. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that heading_w is correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.heading_w.shape == (num_instances,) + assert articulation_data.heading_w.dtype == wp.float32 + + # Forward direction in body frame + forward_vec_b = torch.tensor([1.0, 0.0, 0.0], device=device) + + for i in range(10): + articulation_data._sim_timestamp = i + 1.0 + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + ) + print(articulation_data._sim_bind_root_link_pose_w) + print(articulation_data.FORWARD_VEC_B) + # Compute expected heading: atan2(rotated_forward.y, rotated_forward.x) + rotated_forward = math_utils.quat_apply(root_pose[:, 3:], forward_vec_b.expand(num_instances, 3)) + expected = torch.atan2(rotated_forward[:, 1], rotated_forward[:, 0]) + print(f"expected: {expected}") + + # Compare the computed value + assert torch.allclose(wp.to_torch(articulation_data.heading_w), expected, atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._heading_w.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.heading_w).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._heading_w.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.heading_w) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._heading_w.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.heading_w) == value) + + +class TestRootLinkVelB: + """Tests the root link velocity in body frame properties. + + Tests the following properties: + - root_link_vel_b: velocity projected to body frame + - root_link_lin_vel_b: linear velocity slice (first 3 components) + - root_link_ang_vel_b: angular velocity slice (last 3 components) + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that lin/ang velocities are correct slices. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that root_link_vel_b and its slices are correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check types and shapes + assert articulation_data.root_link_vel_b.shape == (num_instances,) + assert articulation_data.root_link_vel_b.dtype == wp.spatial_vectorf + + assert articulation_data.root_link_lin_vel_b.shape == (num_instances,) + assert articulation_data.root_link_lin_vel_b.dtype == wp.vec3f + + assert articulation_data.root_link_ang_vel_b.shape == (num_instances,) + assert articulation_data.root_link_ang_vel_b.dtype == wp.vec3f + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocity and body COM position + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # Compute expected root_link_vel_w (same as TestRootLinkVelW) + root_link_vel_w = com_vel.clone() + root_link_vel_w[:, :3] += torch.linalg.cross( + root_link_vel_w[:, 3:], + math_utils.quat_apply(root_pose[:, 3:], -body_com_pos[:, 0]), + dim=-1, + ) + + # Project to body frame using quat_rotate_inv + # Linear velocity: quat_rotate_inv(quat, lin_vel) + # Angular velocity: quat_rotate_inv(quat, ang_vel) + lin_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], root_link_vel_w[:, :3]) + ang_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], root_link_vel_w[:, 3:]) + expected_vel_b = torch.cat([lin_vel_b, ang_vel_b], dim=-1) + + # Get computed values + computed_vel_b = wp.to_torch(articulation_data.root_link_vel_b) + computed_lin_vel_b = wp.to_torch(articulation_data.root_link_lin_vel_b) + computed_ang_vel_b = wp.to_torch(articulation_data.root_link_ang_vel_b) + + # Compare full velocity + assert torch.allclose(computed_vel_b, expected_vel_b, atol=1e-6, rtol=1e-6) + + # Check that lin/ang velocities are correct slices + assert torch.allclose(computed_lin_vel_b, computed_vel_b[:, :3], atol=1e-6, rtol=1e-6) + assert torch.allclose(computed_ang_vel_b, computed_vel_b[:, 3:], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._root_link_vel_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.root_link_vel_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._root_link_vel_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.root_link_vel_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._root_link_vel_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.root_link_vel_b) == value) + + +class TestRootComVelB: + """Tests the root center of mass velocity in body frame properties. + + Tests the following properties: + - root_com_vel_b: COM velocity projected to body frame + - root_com_lin_vel_b: linear velocity slice (first 3 components) + - root_com_ang_vel_b: angular velocity slice (last 3 components) + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that the returned value is correctly computed. + - Checks that lin/ang velocities are correct slices. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_correctness(self, mock_newton_manager, num_instances: int, device: str): + """Test that root_com_vel_b and its slices are correctly computed.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Check types and shapes + assert articulation_data.root_com_vel_b.shape == (num_instances,) + assert articulation_data.root_com_vel_b.dtype == wp.spatial_vectorf + + assert articulation_data.root_com_lin_vel_b.shape == (num_instances,) + assert articulation_data.root_com_lin_vel_b.dtype == wp.vec3f + + assert articulation_data.root_com_ang_vel_b.shape == (num_instances,) + assert articulation_data.root_com_ang_vel_b.dtype == wp.vec3f + + for i in range(5): + articulation_data._sim_timestamp = i + 1.0 + + # Generate random root pose with normalized quaternion + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + # Generate random COM velocity (this is root_com_vel_w from simulation) + com_vel_w = torch.rand((num_instances, 6), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel_w, dtype=wp.spatial_vectorf), + ) + + # Project COM velocity to body frame using quat_rotate_inv (quat_conjugate + quat_apply) + lin_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], com_vel_w[:, :3]) + ang_vel_b = math_utils.quat_apply_inverse(root_pose[:, 3:], com_vel_w[:, 3:]) + expected_vel_b = torch.cat([lin_vel_b, ang_vel_b], dim=-1) + + # Get computed values + computed_vel_b = wp.to_torch(articulation_data.root_com_vel_b) + computed_lin_vel_b = wp.to_torch(articulation_data.root_com_lin_vel_b) + computed_ang_vel_b = wp.to_torch(articulation_data.root_com_ang_vel_b) + + # Compare full velocity + assert torch.allclose(computed_vel_b, expected_vel_b, atol=1e-6, rtol=1e-6) + + # Check that lin/ang velocities are correct slices + assert torch.allclose(computed_lin_vel_b, computed_vel_b[:, :3], atol=1e-6, rtol=1e-6) + assert torch.allclose(computed_ang_vel_b, computed_vel_b[:, 3:], atol=1e-6, rtol=1e-6) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_timestamp_invalidation(self, mock_newton_manager, device: str): + """Test that data is invalidated when timestamp is updated.""" + articulation_data, mock_view = self._setup_method(1, device) + + # Check initial timestamp + assert articulation_data._root_com_vel_b.timestamp == -1.0 + assert articulation_data._sim_timestamp == 0.0 + + # Request the property to trigger computation + value = wp.to_torch(articulation_data.root_com_vel_b).clone() + + # Check that buffer timestamp matches sim timestamp + assert articulation_data._root_com_vel_b.timestamp == articulation_data._sim_timestamp + + # Update mock data without changing sim timestamp + new_pose = torch.zeros((1, 7), device=device) + new_pose[:, 3:] = torch.randn((1, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + mock_view.set_mock_data( + root_transforms=wp.from_torch(new_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(torch.rand((1, 6), device=device), dtype=wp.spatial_vectorf), + ) + + # Value should NOT change (cached) + assert torch.all(wp.to_torch(articulation_data.root_com_vel_b) == value) + + # Update sim timestamp + articulation_data._sim_timestamp = 1.0 + + # Buffer timestamp should now be stale + assert articulation_data._root_com_vel_b.timestamp != articulation_data._sim_timestamp + + # Value should now be recomputed (different from cached) + assert not torch.all(wp.to_torch(articulation_data.root_com_vel_b) == value) + + +## +# Test Cases -- Sliced properties. +## + + +class TestRootSlicedProperties: + """Tests the root sliced properties. + + These properties extract position, quaternion, linear velocity, or angular velocity + from the full pose/velocity arrays. + + Tests the following properties: + - root_link_pos_w, root_link_quat_w (from root_link_pose_w) + - root_link_lin_vel_w, root_link_ang_vel_w (from root_link_vel_w) + - root_com_pos_w, root_com_quat_w (from root_com_pose_w) + - root_com_lin_vel_w, root_com_ang_vel_w (from root_com_vel_w) + + For each property, we only check that they are the correct slice of the parent property. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_root_sliced_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all root sliced properties are correct slices of their parent properties.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_link_pose_w slices --- + root_link_pose = wp.to_torch(articulation_data.root_link_pose_w) + root_link_pos = wp.to_torch(articulation_data.root_link_pos_w) + root_link_quat = wp.to_torch(articulation_data.root_link_quat_w) + + assert root_link_pos.shape == (num_instances, 3) + assert root_link_quat.shape == (num_instances, 4) + assert torch.allclose(root_link_pos, root_link_pose[:, :3], atol=1e-6) + assert torch.allclose(root_link_quat, root_link_pose[:, 3:], atol=1e-6) + + # --- Test root_link_vel_w slices --- + root_link_vel = wp.to_torch(articulation_data.root_link_vel_w) + root_link_lin_vel = wp.to_torch(articulation_data.root_link_lin_vel_w) + root_link_ang_vel = wp.to_torch(articulation_data.root_link_ang_vel_w) + + assert root_link_lin_vel.shape == (num_instances, 3) + assert root_link_ang_vel.shape == (num_instances, 3) + assert torch.allclose(root_link_lin_vel, root_link_vel[:, :3], atol=1e-6) + assert torch.allclose(root_link_ang_vel, root_link_vel[:, 3:], atol=1e-6) + + # --- Test root_com_pose_w slices --- + root_com_pose = wp.to_torch(articulation_data.root_com_pose_w) + root_com_pos = wp.to_torch(articulation_data.root_com_pos_w) + root_com_quat = wp.to_torch(articulation_data.root_com_quat_w) + + assert root_com_pos.shape == (num_instances, 3) + assert root_com_quat.shape == (num_instances, 4) + assert torch.allclose(root_com_pos, root_com_pose[:, :3], atol=1e-6) + assert torch.allclose(root_com_quat, root_com_pose[:, 3:], atol=1e-6) + + # --- Test root_com_vel_w slices --- + root_com_vel = wp.to_torch(articulation_data.root_com_vel_w) + root_com_lin_vel = wp.to_torch(articulation_data.root_com_lin_vel_w) + root_com_ang_vel = wp.to_torch(articulation_data.root_com_ang_vel_w) + + assert root_com_lin_vel.shape == (num_instances, 3) + assert root_com_ang_vel.shape == (num_instances, 3) + assert torch.allclose(root_com_lin_vel, root_com_vel[:, :3], atol=1e-6) + assert torch.allclose(root_com_ang_vel, root_com_vel[:, 3:], atol=1e-6) + + +class TestBodySlicedProperties: + """Tests the body sliced properties. + + These properties extract position, quaternion, linear velocity, or angular velocity + from the full pose/velocity arrays. + + Tests the following properties: + - body_link_pos_w, body_link_quat_w (from body_link_pose_w) + - body_link_lin_vel_w, body_link_ang_vel_w (from body_link_vel_w) + - body_com_pos_w, body_com_quat_w (from body_com_pose_w) + - body_com_lin_vel_w, body_com_ang_vel_w (from body_com_vel_w) + + For each property, we only check that they are the correct slice of the parent property. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_body_sliced_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that all body sliced properties are correct slices of their parent properties.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + body_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_pose[..., 3:] = torch.nn.functional.normalize(body_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + body_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(body_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_link_pose_w slices --- + body_link_pose = wp.to_torch(articulation_data.body_link_pose_w) + body_link_pos = wp.to_torch(articulation_data.body_link_pos_w) + body_link_quat = wp.to_torch(articulation_data.body_link_quat_w) + + assert body_link_pos.shape == (num_instances, num_bodies, 3) + assert body_link_quat.shape == (num_instances, num_bodies, 4) + assert torch.allclose(body_link_pos, body_link_pose[..., :3], atol=1e-6) + assert torch.allclose(body_link_quat, body_link_pose[..., 3:], atol=1e-6) + + # --- Test body_link_vel_w slices --- + body_link_vel = wp.to_torch(articulation_data.body_link_vel_w) + body_link_lin_vel = wp.to_torch(articulation_data.body_link_lin_vel_w) + body_link_ang_vel = wp.to_torch(articulation_data.body_link_ang_vel_w) + + assert body_link_lin_vel.shape == (num_instances, num_bodies, 3) + assert body_link_ang_vel.shape == (num_instances, num_bodies, 3) + assert torch.allclose(body_link_lin_vel, body_link_vel[..., :3], atol=1e-6) + assert torch.allclose(body_link_ang_vel, body_link_vel[..., 3:], atol=1e-6) + + # --- Test body_com_pose_w slices --- + body_com_pose = wp.to_torch(articulation_data.body_com_pose_w) + body_com_pos_w = wp.to_torch(articulation_data.body_com_pos_w) + body_com_quat_w = wp.to_torch(articulation_data.body_com_quat_w) + + assert body_com_pos_w.shape == (num_instances, num_bodies, 3) + assert body_com_quat_w.shape == (num_instances, num_bodies, 4) + assert torch.allclose(body_com_pos_w, body_com_pose[..., :3], atol=1e-6) + assert torch.allclose(body_com_quat_w, body_com_pose[..., 3:], atol=1e-6) + + # --- Test body_com_vel_w slices --- + body_com_vel = wp.to_torch(articulation_data.body_com_vel_w) + body_com_lin_vel = wp.to_torch(articulation_data.body_com_lin_vel_w) + body_com_ang_vel = wp.to_torch(articulation_data.body_com_ang_vel_w) + + assert body_com_lin_vel.shape == (num_instances, num_bodies, 3) + assert body_com_ang_vel.shape == (num_instances, num_bodies, 3) + assert torch.allclose(body_com_lin_vel, body_com_vel[..., :3], atol=1e-6) + assert torch.allclose(body_com_ang_vel, body_com_vel[..., 3:], atol=1e-6) + + +class TestBodyComPosQuatB: + """Tests the body center of mass position and quaternion in body frame properties. + + Tests the following properties: + - body_com_pos_b: COM position in body frame (direct sim binding) + - body_com_quat_b: COM orientation in body frame (derived from body_com_pose_b) + + Runs the following checks: + - Checks that the returned values have the correct type and shape. + - Checks that body_com_pos_b returns the simulation data. + - Checks that body_com_quat_b is the quaternion slice of body_com_pose_b. + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_body_com_pos_and_quat_b(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that body_com_pos_b and body_com_quat_b have correct types, shapes, and values.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # --- Test body_com_pos_b --- + # Check the type and shape + assert articulation_data.body_com_pos_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_pos_b.dtype == wp.vec3f + + # Mock data is initialized to zeros + expected_pos = torch.zeros((num_instances, num_bodies, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_pos) + + # Update with random COM positions + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # Check that the property returns the mock data + assert torch.allclose(wp.to_torch(articulation_data.body_com_pos_b), com_pos, atol=1e-6) + + # Verify reference behavior + body_com_pos_ref = articulation_data.body_com_pos_b + articulation_data.body_com_pos_b.fill_(1.0) + expected_ones = torch.ones((num_instances, num_bodies, 3), device=device) + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_ones) + body_com_pos_ref.fill_(2.0) + expected_twos = torch.ones((num_instances, num_bodies, 3), device=device) * 2.0 + assert torch.all(wp.to_torch(articulation_data.body_com_pos_b) == expected_twos) + + # --- Test body_com_quat_b --- + # Check the type and shape + assert articulation_data.body_com_quat_b.shape == (num_instances, num_bodies) + assert articulation_data.body_com_quat_b.dtype == wp.quatf + + # body_com_quat_b is derived from body_com_pose_b which uses identity quaternion + # body_com_pose_b = [body_com_pos_b, identity_quat] + # So body_com_quat_b should be identity quaternion (0, 0, 0, 1) + body_com_quat = wp.to_torch(articulation_data.body_com_quat_b) + expected_quat = torch.zeros((num_instances, num_bodies, 4), device=device) + expected_quat[..., 3] = 1.0 # w component of identity quaternion + + assert torch.allclose(body_com_quat, expected_quat, atol=1e-6) + + +## +# Test Cases -- Backward compatibility. +## + + +# TODO: Remove this test case in the future. +class TestDefaultRootState: + """Tests the deprecated default_root_state property. + + This property combines default_root_pose and default_root_vel into a vec13f state. + It is deprecated in favor of using default_root_pose and default_root_vel directly. + + Runs the following checks: + - Checks that the returned value has the correct type and shape. + - Checks that it correctly combines default_root_pose and default_root_vel. + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_default_root_state(self, mock_newton_manager, num_instances: int, device: str): + """Test that default_root_state correctly combines pose and velocity.""" + articulation_data, _ = self._setup_method(num_instances, device) + + # Check the type and shape + assert articulation_data.default_root_state.shape == (num_instances,) + + # Get the combined state + default_state = wp.to_torch(articulation_data.default_root_state) + assert default_state.shape == (num_instances, 13) + + # Get the individual components + default_pose = wp.to_torch(articulation_data.default_root_pose) + default_vel = wp.to_torch(articulation_data.default_root_vel) + + # Verify the state is the concatenation of pose and velocity + expected_state = torch.cat([default_pose, default_vel], dim=-1) + assert torch.allclose(default_state, expected_state, atol=1e-6) + + # Modify default_root_pose and default_root_vel and verify the state updates + new_pose = torch.zeros((num_instances, 7), device=device) + new_pose[:, :3] = torch.rand((num_instances, 3), device=device) + new_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + new_pose[:, 3:] = torch.nn.functional.normalize(new_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + new_vel = torch.rand((num_instances, 6), device=device) + + # Set the new values + articulation_data.default_root_pose.assign(wp.from_torch(new_pose, dtype=wp.transformf)) + articulation_data.default_root_vel.assign(wp.from_torch(new_vel, dtype=wp.spatial_vectorf)) + + # Verify the state reflects the new values + updated_state = wp.to_torch(articulation_data.default_root_state) + expected_updated_state = torch.cat([new_pose, new_vel], dim=-1) + assert torch.allclose(updated_state, expected_updated_state, atol=1e-6) + + +# TODO: Remove this test case in the future. +class TestDeprecatedRootProperties: + """Tests the deprecated root pose/velocity properties. + + These are backward compatibility aliases that just return the corresponding new property. + + Tests the following deprecated -> new property mappings: + - root_pose_w -> root_link_pose_w + - root_pos_w -> root_link_pos_w + - root_quat_w -> root_link_quat_w + - root_vel_w -> root_com_vel_w + - root_lin_vel_w -> root_com_lin_vel_w + - root_ang_vel_w -> root_com_ang_vel_w + - root_lin_vel_b -> root_com_lin_vel_b + - root_ang_vel_b -> root_com_ang_vel_b + """ + + def _setup_method(self, num_instances: int, device: str) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_deprecated_root_properties(self, mock_newton_manager, num_instances: int, device: str): + """Test that all deprecated root properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + root_pose = torch.zeros((num_instances, 7), device=device) + root_pose[:, :3] = torch.rand((num_instances, 3), device=device) + root_pose[:, 3:] = torch.randn((num_instances, 4), device=device) + root_pose[:, 3:] = torch.nn.functional.normalize(root_pose[:, 3:], p=2.0, dim=-1, eps=1e-12) + + com_vel = torch.rand((num_instances, 6), device=device) + body_com_pos = torch.rand((num_instances, 1, 3), device=device) + + mock_view.set_mock_data( + root_transforms=wp.from_torch(root_pose, dtype=wp.transformf), + root_velocities=wp.from_torch(com_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test root_pose_w -> root_link_pose_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_pose_w), + wp.to_torch(articulation_data.root_link_pose_w), + atol=1e-6, + ) + + # --- Test root_pos_w -> root_link_pos_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_pos_w), + wp.to_torch(articulation_data.root_link_pos_w), + atol=1e-6, + ) + + # --- Test root_quat_w -> root_link_quat_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_quat_w), + wp.to_torch(articulation_data.root_link_quat_w), + atol=1e-6, + ) + + # --- Test root_vel_w -> root_com_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_vel_w), + wp.to_torch(articulation_data.root_com_vel_w), + atol=1e-6, + ) + + # --- Test root_lin_vel_w -> root_com_lin_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_lin_vel_w), + wp.to_torch(articulation_data.root_com_lin_vel_w), + atol=1e-6, + ) + + # --- Test root_ang_vel_w -> root_com_ang_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.root_ang_vel_w), + wp.to_torch(articulation_data.root_com_ang_vel_w), + atol=1e-6, + ) + + # --- Test root_lin_vel_b -> root_com_lin_vel_b --- + assert torch.allclose( + wp.to_torch(articulation_data.root_lin_vel_b), + wp.to_torch(articulation_data.root_com_lin_vel_b), + atol=1e-6, + ) + + # --- Test root_ang_vel_b -> root_com_ang_vel_b --- + assert torch.allclose( + wp.to_torch(articulation_data.root_ang_vel_b), + wp.to_torch(articulation_data.root_com_ang_vel_b), + atol=1e-6, + ) + + +class TestDeprecatedBodyProperties: + """Tests the deprecated body pose/velocity/acceleration properties. + + These are backward compatibility aliases that just return the corresponding new property. + + Tests the following deprecated -> new property mappings: + - body_pose_w -> body_link_pose_w + - body_pos_w -> body_link_pos_w + - body_quat_w -> body_link_quat_w + - body_vel_w -> body_com_vel_w + - body_lin_vel_w -> body_com_lin_vel_w + - body_ang_vel_w -> body_com_ang_vel_w + - body_acc_w -> body_com_acc_w + - body_lin_acc_w -> body_com_lin_acc_w + - body_ang_acc_w -> body_com_ang_acc_w + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_all_deprecated_body_properties( + self, mock_newton_manager, num_instances: int, num_bodies: int, device: str + ): + """Test that all deprecated body properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data to ensure non-trivial values + articulation_data._sim_timestamp = 1.0 + + body_pose = torch.zeros((num_instances, num_bodies, 7), device=device) + body_pose[..., :3] = torch.rand((num_instances, num_bodies, 3), device=device) + body_pose[..., 3:] = torch.randn((num_instances, num_bodies, 4), device=device) + body_pose[..., 3:] = torch.nn.functional.normalize(body_pose[..., 3:], p=2.0, dim=-1, eps=1e-12) + + body_vel = torch.rand((num_instances, num_bodies, 6), device=device) + body_com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + + mock_view.set_mock_data( + link_transforms=wp.from_torch(body_pose, dtype=wp.transformf), + link_velocities=wp.from_torch(body_vel, dtype=wp.spatial_vectorf), + body_com_pos=wp.from_torch(body_com_pos, dtype=wp.vec3f), + ) + + # --- Test body_pose_w -> body_link_pose_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_pose_w), + wp.to_torch(articulation_data.body_link_pose_w), + atol=1e-6, + ) + + # --- Test body_pos_w -> body_link_pos_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_pos_w), + wp.to_torch(articulation_data.body_link_pos_w), + atol=1e-6, + ) + + # --- Test body_quat_w -> body_link_quat_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_quat_w), + wp.to_torch(articulation_data.body_link_quat_w), + atol=1e-6, + ) + + # --- Test body_vel_w -> body_com_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_vel_w), + wp.to_torch(articulation_data.body_com_vel_w), + atol=1e-6, + ) + + # --- Test body_lin_vel_w -> body_com_lin_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_lin_vel_w), + wp.to_torch(articulation_data.body_com_lin_vel_w), + atol=1e-6, + ) + + # --- Test body_ang_vel_w -> body_com_ang_vel_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_ang_vel_w), + wp.to_torch(articulation_data.body_com_ang_vel_w), + atol=1e-6, + ) + + # --- Test body_acc_w -> body_com_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_acc_w), + wp.to_torch(articulation_data.body_com_acc_w), + atol=1e-6, + ) + + # --- Test body_lin_acc_w -> body_com_lin_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_lin_acc_w), + wp.to_torch(articulation_data.body_com_lin_acc_w), + atol=1e-6, + ) + + # --- Test body_ang_acc_w -> body_com_ang_acc_w --- + assert torch.allclose( + wp.to_torch(articulation_data.body_ang_acc_w), + wp.to_torch(articulation_data.body_com_ang_acc_w), + atol=1e-6, + ) + + +class TestDeprecatedComProperties: + """Tests the deprecated COM pose properties. + + Tests the following deprecated -> new property mappings: + - com_pos_b -> body_com_pos_b + - com_quat_b -> body_com_quat_b + """ + + def _setup_method( + self, num_instances: int, num_bodies: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, num_bodies, 1, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_bodies", [1, 3]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_deprecated_com_properties(self, mock_newton_manager, num_instances: int, num_bodies: int, device: str): + """Test that deprecated COM properties match their replacements.""" + articulation_data, mock_view = self._setup_method(num_instances, num_bodies, device) + + # Set up random mock data + com_pos = torch.rand((num_instances, num_bodies, 3), device=device) + mock_view.set_mock_data( + body_com_pos=wp.from_torch(com_pos, dtype=wp.vec3f), + ) + + # --- Test com_pos_b -> body_com_pos_b --- + assert torch.allclose( + wp.to_torch(articulation_data.com_pos_b), + wp.to_torch(articulation_data.body_com_pos_b), + atol=1e-6, + ) + + # --- Test com_quat_b -> body_com_quat_b --- + assert torch.allclose( + wp.to_torch(articulation_data.com_quat_b), + wp.to_torch(articulation_data.body_com_quat_b), + atol=1e-6, + ) + + +class TestDeprecatedJointMiscProperties: + """Tests the deprecated joint and misc properties. + + Tests the following deprecated -> new property mappings: + - joint_limits -> joint_pos_limits + - joint_friction -> joint_friction_coeff + - applied_torque -> applied_effort + - computed_torque -> computed_effort + - joint_dynamic_friction -> joint_dynamic_friction_coeff + - joint_effort_target -> actuator_effort_target + - joint_viscous_friction -> joint_viscous_friction_coeff + - joint_velocity_limits -> joint_vel_limits + + Note: fixed_tendon_limit -> fixed_tendon_pos_limits is tested separately + as it raises NotImplementedError. + """ + + def _setup_method( + self, num_instances: int, num_joints: int, device: str + ) -> tuple[ArticulationData, MockNewtonArticulationView]: + mock_view = MockNewtonArticulationView(num_instances, 1, num_joints, device) + mock_view.set_mock_data() + + articulation_data = ArticulationData( + mock_view, + device, + ) + return articulation_data, mock_view + + @pytest.mark.parametrize("num_instances", [1, 2]) + @pytest.mark.parametrize("num_joints", [1, 6]) + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_deprecated_joint_properties(self, mock_newton_manager, num_instances: int, num_joints: int, device: str): + """Test that deprecated joint properties match their replacements.""" + articulation_data, _ = self._setup_method(num_instances, num_joints, device) + + # --- Test joint_limits -> joint_pos_limits --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_limits), + wp.to_torch(articulation_data.joint_pos_limits), + atol=1e-6, + ) + + # --- Test joint_friction -> joint_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_friction), + wp.to_torch(articulation_data.joint_friction_coeff), + atol=1e-6, + ) + + # --- Test applied_torque -> applied_effort --- + assert torch.allclose( + wp.to_torch(articulation_data.applied_torque), + wp.to_torch(articulation_data.applied_effort), + atol=1e-6, + ) + + # --- Test computed_torque -> computed_effort --- + assert torch.allclose( + wp.to_torch(articulation_data.computed_torque), + wp.to_torch(articulation_data.computed_effort), + atol=1e-6, + ) + + # --- Test joint_dynamic_friction -> joint_dynamic_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_dynamic_friction), + wp.to_torch(articulation_data.joint_dynamic_friction_coeff), + atol=1e-6, + ) + + # --- Test joint_effort_target -> actuator_effort_target --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_effort_target), + wp.to_torch(articulation_data.actuator_effort_target), + atol=1e-6, + ) + + # --- Test joint_viscous_friction -> joint_viscous_friction_coeff --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_viscous_friction), + wp.to_torch(articulation_data.joint_viscous_friction_coeff), + atol=1e-6, + ) + + # --- Test joint_velocity_limits -> joint_vel_limits --- + assert torch.allclose( + wp.to_torch(articulation_data.joint_velocity_limits), + wp.to_torch(articulation_data.joint_vel_limits), + atol=1e-6, + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) + def test_fixed_tendon_limit_not_implemented(self, mock_newton_manager, device: str): + """Test that fixed_tendon_limit raises NotImplementedError (same as fixed_tendon_pos_limits).""" + articulation_data, _ = self._setup_method(1, 1, device) + + with pytest.raises(NotImplementedError): + _ = articulation_data.fixed_tendon_limit + + +## +# Main +## + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/source/isaaclab_newton/test/conftest.py b/source/isaaclab_newton/test/conftest.py new file mode 100644 index 00000000000..2a165ea4326 --- /dev/null +++ b/source/isaaclab_newton/test/conftest.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Pytest configuration for isaaclab_newton tests. + +This conftest.py adds test subdirectories to the Python path so that local +helper modules (like mock_interface.py) can be imported by test files. +""" + +import sys +from pathlib import Path + +# Add test subdirectories to path so local modules can be imported +test_dir = Path(__file__).parent +sys.path.insert(0, str(test_dir / "assets" / "articulation_data")) diff --git a/source/isaaclab_newton/tools/compare_benchmarks.py b/source/isaaclab_newton/tools/compare_benchmarks.py new file mode 100644 index 00000000000..4a36485380a --- /dev/null +++ b/source/isaaclab_newton/tools/compare_benchmarks.py @@ -0,0 +1,543 @@ +# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Benchmark comparison tool for ArticulationData performance analysis. + +This script compares two benchmark JSON files and identifies performance +regressions and improvements. + +Usage: + python compare_benchmarks.py baseline.json current.json [--threshold 10] + +Example: + python compare_benchmarks.py articulation_data_2025-12-01.json articulation_data_2025-12-11.json +""" + +from __future__ import annotations + +import argparse +import json +import math +import sys +from dataclasses import dataclass +from enum import Enum + + +class ChangeType(Enum): + """Type of performance change.""" + + REGRESSION = "regression" + IMPROVEMENT = "improvement" + UNCHANGED = "unchanged" + NEW = "new" + REMOVED = "removed" + + +@dataclass +class PropertyComparison: + """Comparison result for a single property.""" + + name: str + baseline_mean_us: float | None + current_mean_us: float | None + baseline_std_us: float | None + current_std_us: float | None + absolute_change_us: float | None + percent_change: float | None + change_type: ChangeType + combined_std_us: float | None = None + """Combined standard deviation: sqrt(baseline_std^2 + current_std^2).""" + sigma_change: float | None = None + """Change expressed in units of combined standard deviation.""" + + +def load_benchmark(filepath: str) -> dict: + """Load a benchmark JSON file. + + Args: + filepath: Path to the JSON file. + + Returns: + Parsed JSON data. + + Raises: + FileNotFoundError: If the file doesn't exist. + json.JSONDecodeError: If the file is not valid JSON. + """ + with open(filepath) as f: + return json.load(f) + + +def extract_results(benchmark_data: dict) -> dict[str, dict]: + """Extract results from benchmark data into a lookup dict. + + Args: + benchmark_data: Parsed benchmark JSON data. + + Returns: + Dictionary mapping property names to their result data. + """ + results = {} + for result in benchmark_data.get("results", []): + name = result.get("name") + if name: + results[name] = result + return results + + +def compare_benchmarks( + baseline_data: dict, + current_data: dict, + regression_threshold: float = 10.0, + improvement_threshold: float = 10.0, + sigma_threshold: float = 1.0, +) -> list[PropertyComparison]: + """Compare two benchmark results. + + A change is only considered significant if it exceeds BOTH: + 1. The percentage threshold (regression_threshold or improvement_threshold) + 2. The sigma threshold (change must be > sigma_threshold * combined_std) + + This prevents flagging changes that are within statistical noise. + + Args: + baseline_data: Baseline benchmark JSON data. + current_data: Current benchmark JSON data. + regression_threshold: Percent increase to consider a regression. + improvement_threshold: Percent decrease to consider an improvement. + sigma_threshold: Number of combined standard deviations the change must + exceed to be considered significant. Default 1.0 means changes within + 1 std dev of combined uncertainty are considered unchanged. + + Returns: + List of PropertyComparison objects. + """ + baseline_results = extract_results(baseline_data) + current_results = extract_results(current_data) + + all_properties = set(baseline_results.keys()) | set(current_results.keys()) + comparisons = [] + + for prop_name in sorted(all_properties): + baseline = baseline_results.get(prop_name) + current = current_results.get(prop_name) + + if baseline is None: + # New property in current (current must exist since prop is in all_properties) + assert current is not None + comparisons.append( + PropertyComparison( + name=prop_name, + baseline_mean_us=None, + current_mean_us=current["mean_time_us"], + baseline_std_us=None, + current_std_us=current["std_time_us"], + absolute_change_us=None, + percent_change=None, + change_type=ChangeType.NEW, + ) + ) + elif current is None: + # Property removed in current + comparisons.append( + PropertyComparison( + name=prop_name, + baseline_mean_us=baseline["mean_time_us"], + current_mean_us=None, + baseline_std_us=baseline["std_time_us"], + current_std_us=None, + absolute_change_us=None, + percent_change=None, + change_type=ChangeType.REMOVED, + ) + ) + else: + # Both exist - compare + baseline_mean = baseline["mean_time_us"] + current_mean = current["mean_time_us"] + baseline_std = baseline["std_time_us"] + current_std = current["std_time_us"] + absolute_change = current_mean - baseline_mean + + # Compute combined standard deviation + combined_std = math.sqrt(baseline_std**2 + current_std**2) + + # Compute change in units of sigma + if combined_std > 0: + sigma_change = absolute_change / combined_std + else: + sigma_change = float("inf") if absolute_change != 0 else 0.0 + + if baseline_mean > 0: + percent_change = (absolute_change / baseline_mean) * 100 + else: + percent_change = 0.0 if current_mean == 0 else float("inf") + + # Determine change type: + # A change is significant only if it exceeds BOTH the percentage threshold + # AND the sigma threshold (i.e., the change is outside statistical noise) + is_statistically_significant = abs(sigma_change) > sigma_threshold + + if percent_change > regression_threshold and is_statistically_significant: + change_type = ChangeType.REGRESSION + elif percent_change < -improvement_threshold and is_statistically_significant: + change_type = ChangeType.IMPROVEMENT + else: + change_type = ChangeType.UNCHANGED + + comparisons.append( + PropertyComparison( + name=prop_name, + baseline_mean_us=baseline_mean, + current_mean_us=current_mean, + baseline_std_us=baseline_std, + current_std_us=current_std, + absolute_change_us=absolute_change, + percent_change=percent_change, + change_type=change_type, + combined_std_us=combined_std, + sigma_change=sigma_change, + ) + ) + + return comparisons + + +def print_metadata_comparison(baseline_data: dict, current_data: dict): + """Print comparison of metadata between two benchmarks. + + Args: + baseline_data: Baseline benchmark JSON data. + current_data: Current benchmark JSON data. + """ + print("\n" + "=" * 115) + print("BENCHMARK COMPARISON") + print("=" * 115) + + baseline_meta = baseline_data.get("metadata", {}) + current_meta = current_data.get("metadata", {}) + + # Repository info + baseline_repo = baseline_meta.get("repository", {}) + current_repo = current_meta.get("repository", {}) + + print(f"\n{'':30} {'BASELINE':>30} {'CURRENT':>30}") + print("-" * 100) + print(f"{'Timestamp:':<30} {baseline_meta.get('timestamp', 'N/A'):>30} {current_meta.get('timestamp', 'N/A'):>30}") + print( + f"{'Commit:':<30} {baseline_repo.get('commit_hash_short', 'N/A'):>30} {current_repo.get('commit_hash_short', 'N/A'):>30}" + ) + print(f"{'Branch:':<30} {baseline_repo.get('branch', 'N/A'):>30} {current_repo.get('branch', 'N/A'):>30}") + + # Config + baseline_config = baseline_meta.get("config", {}) + current_config = current_meta.get("config", {}) + + print(f"\n{'Configuration:':<30}") + print( + f"{' Iterations:':<30} {baseline_config.get('num_iterations', 'N/A'):>30} {current_config.get('num_iterations', 'N/A'):>30}" + ) + print( + f"{' Instances:':<30} {baseline_config.get('num_instances', 'N/A'):>30} {current_config.get('num_instances', 'N/A'):>30}" + ) + print( + f"{' Bodies:':<30} {baseline_config.get('num_bodies', 'N/A'):>30} {current_config.get('num_bodies', 'N/A'):>30}" + ) + print( + f"{' Joints:':<30} {baseline_config.get('num_joints', 'N/A'):>30} {current_config.get('num_joints', 'N/A'):>30}" + ) + + # Hardware + baseline_hw = baseline_meta.get("hardware", {}) + current_hw = current_meta.get("hardware", {}) + + baseline_gpu = baseline_hw.get("gpu", {}) + current_gpu = current_hw.get("gpu", {}) + + baseline_gpu_name = "N/A" + current_gpu_name = "N/A" + if baseline_gpu.get("devices"): + baseline_gpu_name = baseline_gpu["devices"][0].get("name", "N/A") + if current_gpu.get("devices"): + current_gpu_name = current_gpu["devices"][0].get("name", "N/A") + + print(f"\n{'Hardware:':<30}") + print(f"{' GPU:':<30} {baseline_gpu_name:>30} {current_gpu_name:>30}") + + +def print_comparison_results( + comparisons: list[PropertyComparison], + show_unchanged: bool = False, +): + """Print comparison results. + + Args: + comparisons: List of property comparisons. + show_unchanged: Whether to show unchanged properties. + """ + # Separate by change type + regressions = [c for c in comparisons if c.change_type == ChangeType.REGRESSION] + improvements = [c for c in comparisons if c.change_type == ChangeType.IMPROVEMENT] + unchanged = [c for c in comparisons if c.change_type == ChangeType.UNCHANGED] + new_props = [c for c in comparisons if c.change_type == ChangeType.NEW] + removed_props = [c for c in comparisons if c.change_type == ChangeType.REMOVED] + + # Sort regressions by percent change (worst first) + regressions.sort(key=lambda x: x.percent_change or 0, reverse=True) + # Sort improvements by percent change (best first) + improvements.sort(key=lambda x: x.percent_change or 0) + + # Print regressions + if regressions: + print("\n" + "=" * 115) + print(f"🔴 REGRESSIONS ({len(regressions)} properties)") + print("=" * 115) + print( + f"\n{'Property':<35} {'Baseline (µs)':>12} {'Current (µs)':>12} {'Change':>12} {'% Change':>10} {'σ Change':>10}" + ) + print("-" * 115) + for comp in regressions: + change_str = f"+{comp.absolute_change_us:.2f}" if comp.absolute_change_us else "N/A" + pct_str = f"+{comp.percent_change:.1f}%" if comp.percent_change else "N/A" + sigma_str = f"+{comp.sigma_change:.1f}σ" if comp.sigma_change else "N/A" + print( + f"{comp.name:<35} {comp.baseline_mean_us:>12.2f} {comp.current_mean_us:>12.2f} " + f"{change_str:>12} {pct_str:>10} {sigma_str:>10}" + ) + + # Print improvements + if improvements: + print("\n" + "=" * 115) + print(f"🟢 IMPROVEMENTS ({len(improvements)} properties)") + print("=" * 115) + print( + f"\n{'Property':<35} {'Baseline (µs)':>12} {'Current (µs)':>12} {'Change':>12} {'% Change':>10} {'σ Change':>10}" + ) + print("-" * 115) + for comp in improvements: + change_str = f"{comp.absolute_change_us:.2f}" if comp.absolute_change_us else "N/A" + pct_str = f"{comp.percent_change:.1f}%" if comp.percent_change else "N/A" + sigma_str = f"{comp.sigma_change:.1f}σ" if comp.sigma_change else "N/A" + print( + f"{comp.name:<35} {comp.baseline_mean_us:>12.2f} {comp.current_mean_us:>12.2f} " + f"{change_str:>12} {pct_str:>10} {sigma_str:>10}" + ) + + # Print unchanged (if requested) + if show_unchanged and unchanged: + print("\n" + "=" * 115) + print(f"⚪ UNCHANGED ({len(unchanged)} properties)") + print("=" * 115) + print( + f"\n{'Property':<35} {'Baseline (µs)':>12} {'Current (µs)':>12} {'Change':>12} {'% Change':>10} {'σ Change':>10}" + ) + print("-" * 115) + for comp in unchanged: + change_str = f"{comp.absolute_change_us:+.2f}" if comp.absolute_change_us else "N/A" + pct_str = f"{comp.percent_change:+.1f}%" if comp.percent_change else "N/A" + sigma_str = f"{comp.sigma_change:+.1f}σ" if comp.sigma_change else "N/A" + print( + f"{comp.name:<35} {comp.baseline_mean_us:>12.2f} {comp.current_mean_us:>12.2f} " + f"{change_str:>12} {pct_str:>10} {sigma_str:>10}" + ) + + # Print new properties + if new_props: + print("\n" + "=" * 115) + print(f"🆕 NEW PROPERTIES ({len(new_props)} properties)") + print("=" * 115) + for comp in new_props: + print(f" - {comp.name}: {comp.current_mean_us:.2f} µs") + + # Print removed properties + if removed_props: + print("\n" + "=" * 115) + print(f"❌ REMOVED PROPERTIES ({len(removed_props)} properties)") + print("=" * 115) + for comp in removed_props: + print(f" - {comp.name}: was {comp.baseline_mean_us:.2f} µs") + + # Print summary + print("\n" + "=" * 115) + print("SUMMARY") + print("=" * 115) + total = len(comparisons) + print(f"\n Total properties compared: {total}") + print(f" 🔴 Regressions: {len(regressions):>4} ({100 * len(regressions) / total:.1f}%)") + print(f" 🟢 Improvements: {len(improvements):>4} ({100 * len(improvements) / total:.1f}%)") + print(f" ⚪ Unchanged: {len(unchanged):>4} ({100 * len(unchanged) / total:.1f}%)") + if new_props: + print(f" 🆕 New: {len(new_props):>4}") + if removed_props: + print(f" ❌ Removed: {len(removed_props):>4}") + + # Overall verdict + print("\n" + "-" * 115) + if regressions: + print(f" ⚠️ VERDICT: {len(regressions)} regression(s) detected!") + return 1 # Exit code for CI + else: + print(" ✅ VERDICT: No regressions detected.") + return 0 + + +def export_comparison_json( + comparisons: list[PropertyComparison], + baseline_data: dict, + current_data: dict, + filename: str, +): + """Export comparison results to JSON. + + Args: + comparisons: List of property comparisons. + baseline_data: Baseline benchmark data. + current_data: Current benchmark data. + filename: Output filename. + """ + output = { + "baseline": { + "file": baseline_data.get("metadata", {}).get("timestamp", "Unknown"), + "commit": baseline_data.get("metadata", {}).get("repository", {}).get("commit_hash_short", "Unknown"), + }, + "current": { + "file": current_data.get("metadata", {}).get("timestamp", "Unknown"), + "commit": current_data.get("metadata", {}).get("repository", {}).get("commit_hash_short", "Unknown"), + }, + "regressions": [], + "improvements": [], + "unchanged": [], + "new": [], + "removed": [], + } + + for comp in comparisons: + entry = { + "name": comp.name, + "baseline_mean_us": comp.baseline_mean_us, + "current_mean_us": comp.current_mean_us, + "baseline_std_us": comp.baseline_std_us, + "current_std_us": comp.current_std_us, + "absolute_change_us": comp.absolute_change_us, + "percent_change": comp.percent_change, + "combined_std_us": comp.combined_std_us, + "sigma_change": comp.sigma_change, + } + + if comp.change_type == ChangeType.REGRESSION: + output["regressions"].append(entry) + elif comp.change_type == ChangeType.IMPROVEMENT: + output["improvements"].append(entry) + elif comp.change_type == ChangeType.UNCHANGED: + output["unchanged"].append(entry) + elif comp.change_type == ChangeType.NEW: + output["new"].append(entry) + elif comp.change_type == ChangeType.REMOVED: + output["removed"].append(entry) + + with open(filename, "w") as f: + json.dump(output, f, indent=2) + + print(f"\nComparison exported to {filename}") + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Compare two ArticulationData benchmark JSON files and find regressions.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "baseline", + type=str, + help="Path to baseline benchmark JSON file.", + ) + parser.add_argument( + "current", + type=str, + help="Path to current benchmark JSON file.", + ) + parser.add_argument( + "--regression-threshold", + "-r", + type=float, + default=10.0, + help="Percent increase threshold to consider a regression.", + ) + parser.add_argument( + "--improvement-threshold", + "-i", + type=float, + default=10.0, + help="Percent decrease threshold to consider an improvement.", + ) + parser.add_argument( + "--sigma", + "-s", + type=float, + default=1.0, + help=( + "Number of standard deviations the change must exceed to be significant. " + "Changes within this many std devs of combined uncertainty are considered noise." + ), + ) + parser.add_argument( + "--show-unchanged", + "-u", + action="store_true", + help="Show unchanged properties in output.", + ) + parser.add_argument( + "--export", + "-e", + type=str, + default=None, + help="Export comparison results to JSON file.", + ) + parser.add_argument( + "--ci", + action="store_true", + help="CI mode: exit with code 1 if regressions are found.", + ) + + args = parser.parse_args() + + # Load benchmark files + try: + baseline_data = load_benchmark(args.baseline) + current_data = load_benchmark(args.current) + except FileNotFoundError as e: + print(f"Error: {e}") + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error parsing JSON: {e}") + sys.exit(1) + + # Print metadata comparison + print_metadata_comparison(baseline_data, current_data) + + # Compare benchmarks + comparisons = compare_benchmarks( + baseline_data, + current_data, + regression_threshold=args.regression_threshold, + improvement_threshold=args.improvement_threshold, + sigma_threshold=args.sigma, + ) + + # Print results + exit_code = print_comparison_results(comparisons, show_unchanged=args.show_unchanged) + + # Export if requested + if args.export: + export_comparison_json(comparisons, baseline_data, current_data, args.export) + + # Exit with appropriate code in CI mode + if args.ci: + sys.exit(exit_code) + + +if __name__ == "__main__": + main()