Skip to content
13 changes: 13 additions & 0 deletions genesis/engine/solvers/rigid/rigid_solver_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4920,6 +4920,19 @@ def add_weld_constraint(self, link1_idx, link2_idx, envs_idx=None, *, unsafe=Fal
_, link2_idx, envs_idx = self._sanitize_1D_io_variables(
None, link2_idx, self.n_links, envs_idx, idx_name="links_idx", skip_allocation=True, unsafe=unsafe
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not cast from torch to numpy, this is irrelevant.

if torch.is_tensor(link1_idx):
link1_idx = link1_idx.cpu().numpy()
if torch.is_tensor(link2_idx):
link2_idx = link2_idx.cpu().numpy()
if envs_idx is not None and torch.is_tensor(envs_idx):
envs_idx = envs_idx.cpu().numpy()

if envs_idx is not None and envs_idx.shape[0] > 1:
if link1_idx.shape[0] == 1:
link1_idx = np.repeat(link1_idx, envs_idx.shape[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use torch.repeat

Copy link
Collaborator

@duburcqa duburcqa Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, no need to check if it is a tensor. It has been sanitized as torch tensor already.

if link2_idx.shape[0] == 1:
link2_idx = np.repeat(link2_idx, envs_idx.shape[0])
self._kernel_add_weld_constraint(link1_idx, link2_idx, envs_idx)

@ti.kernel
Expand Down