Skip to content

Commit 9c067b1

Browse files
committed
Fix camera.set_pose()
1 parent fc557be commit 9c067b1

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

genesis/vis/camera.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -594,36 +594,42 @@ def set_pose(self, transform=None, pos=None, lookat=None, up=None, env_idx=None)
594594
The environment indices. If not provided, the camera pose will be set for all environments.
595595
"""
596596
# Check that all provided inputs are of the same type (either all torch.Tensor or all numpy.ndarray)
597-
if transform is not None:
598-
transform = torch.as_tensor(transform, dtype=gs.tc_float, device=gs.device)
599-
if pos is not None:
600-
pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device)
601-
if lookat is not None:
602-
lookat = torch.as_tensor(lookat, dtype=gs.tc_float, device=gs.device)
603-
if up is not None:
604-
up = torch.as_tensor(up, dtype=gs.tc_float, device=gs.device)
605-
606-
# Expand to n_envs
607597
n_envs = max(self._visualizer.scene.n_envs, 1)
608-
if env_idx is None:
609-
env_idx = torch.arange(n_envs)
610598
if transform is not None:
599+
transform = torch.as_tensor(transform, dtype=gs.tc_float, device=gs.device)
611600
if transform.shape[-2:] != (4, 4):
612601
raise ValueError(f"Transform shape {transform.shape} does not match (4, 4)")
613602
if transform.ndim == 2:
614603
transform = transform.expand((n_envs, 4, 4))
604+
else:
605+
transform = self._multi_env_transform_tensor
606+
615607
if pos is not None:
608+
pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device)
616609
assert pos.shape[-1] == 3, f"Pos shape {pos.shape} does not match (n_envs, 3)"
617610
if pos.ndim == 1:
618611
pos = pos.expand((n_envs, 3))
612+
else:
613+
pos = self._multi_env_pos_tensor
614+
619615
if lookat is not None:
616+
lookat = torch.as_tensor(lookat, dtype=gs.tc_float, device=gs.device)
620617
assert lookat.shape[-1] == 3, f"Lookat shape {lookat.shape} does not match (n_envs, 3)"
621618
if lookat.ndim == 1:
622619
lookat = lookat.expand((n_envs, 3))
620+
else:
621+
lookat = self._multi_env_lookat_tensor
622+
623623
if up is not None:
624+
up = torch.as_tensor(up, dtype=gs.tc_float, device=gs.device)
624625
assert up.shape[-1] == 3, f"Up shape {up.shape} does not match (n_envs, 3)"
625626
if up.ndim == 1:
626627
up = up.expand((n_envs, 3))
628+
else:
629+
up = self._multi_env_up_tensor
630+
631+
if env_idx is None:
632+
env_idx = torch.arange(n_envs)
627633

628634
assert (
629635
transform is None or transform.shape[0] == env_idx.shape[0]
@@ -643,15 +649,15 @@ def set_pose(self, transform=None, pos=None, lookat=None, up=None, env_idx=None)
643649
new_lookat = self._multi_env_lookat_tensor[env_idx]
644650
new_up = self._multi_env_up_tensor[env_idx]
645651
if transform is not None:
646-
new_transform = transform
652+
new_transform = transform.clone()
647653
new_pos, new_lookat, new_up = gu.T_to_pos_lookat_up(new_transform)
648654
else:
649655
if pos is not None:
650-
new_pos = pos
656+
new_pos = pos.clone()
651657
if lookat is not None:
652-
new_lookat = lookat
658+
new_lookat = lookat.clone()
653659
if up is not None:
654-
new_up = up
660+
new_up = up.clone()
655661
new_transform = gu.pos_lookat_up_to_T(new_pos, new_lookat, new_up)
656662

657663
new_quat = _T_to_quat_for_madrona(new_transform)

0 commit comments

Comments
 (0)