diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 7812bc4dab..014a8f26d1 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -2443,6 +2443,46 @@ def get_qpos(self, qs_idx_local=None, envs_idx=None, *, unsafe=False): qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True) return self._solver.get_qpos(qs_idx, envs_idx, unsafe=unsafe) + @gs.assert_built + def get_joints_anchor_pos(self, joints_idx_local=None, envs_idx=None, *, unsafe=False): + """ + Returns anchor position of the entity's joints. This is the position of the joint in the world frame. + + Parameters + ---------- + joints_idx_local : None | array_like, optional + The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. + + Returns + ------- + anchor_pos : torch.Tensor, shape (n_joints, 3) or (n_envs, n_joints, 3) + The anchor position of the entity's joints. + """ + joints_idx = self._get_idx(joints_idx_local, self.n_joints, self._joint_start, unsafe=True) + return self._solver.get_joints_anchor_pos(joints_idx, envs_idx, unsafe=unsafe).squeeze(-2) + + @gs.assert_built + def get_joints_anchor_axis(self, joints_idx_local=None, envs_idx=None, *, unsafe=False): + """ + Returns anchor axis of the entity's joints represented as a unit vector. + + Parameters + ---------- + joints_idx_local : None | array_like, optional + The indices of the dofs to get. If None, all dofs will be returned. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. + envs_idx : None | array_like, optional + The indices of the environments. If None, all environments will be considered. Defaults to None. + + Returns + ------- + anchor_axis : torch.Tensor, shape (n_joints, 3) or (n_envs, n_joints, 3) + The anchor axis of the entity's joints. + """ + joints_idx = self._get_idx(joints_idx_local, self.n_joints, self._joint_start, unsafe=True) + return self._solver.get_joints_anchor_axis(joints_idx, envs_idx, unsafe=unsafe).squeeze(-2) + @gs.assert_built def get_dofs_control_force(self, dofs_idx_local=None, envs_idx=None, *, unsafe=False): """ diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 015d36761d..33d01e9180 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -4708,6 +4708,14 @@ def get_qpos(self, qs_idx=None, envs_idx=None, *, unsafe=False): tensor = ti_field_to_torch(self.qpos, envs_idx, qs_idx, transpose=True, unsafe=unsafe) return tensor.squeeze(0) if self.n_envs == 0 else tensor + def get_joints_anchor_pos(self, joints_idx=None, envs_idx=None, *, unsafe=False): + tensor = ti_field_to_torch(self.joints_state.xanchor, envs_idx, joints_idx, transpose=True, unsafe=unsafe) + return tensor.squeeze(0) if self.n_envs == 0 else tensor + + def get_joints_anchor_axis(self, joints_idx=None, envs_idx=None, *, unsafe=False): + tensor = ti_field_to_torch(self.joints_state.xaxis, envs_idx, joints_idx, transpose=True, unsafe=unsafe) + return tensor.squeeze(0) if self.n_envs == 0 else tensor + def get_dofs_control_force(self, dofs_idx=None, envs_idx=None, *, unsafe=False): _tensor, dofs_idx, envs_idx = self._sanitize_1D_io_variables( None, dofs_idx, self.n_dofs, envs_idx, unsafe=unsafe