@@ -594,36 +594,42 @@ def set_pose(self, transform=None, pos=None, lookat=None, up=None, env_idx=None)
594594 The environment indices. If not provided, the camera pose will be set for all environments.
595595 """
596596 # Check that all provided inputs are of the same type (either all torch.Tensor or all numpy.ndarray)
597- if transform is not None :
598- transform = torch .as_tensor (transform , dtype = gs .tc_float , device = gs .device )
599- if pos is not None :
600- pos = torch .as_tensor (pos , dtype = gs .tc_float , device = gs .device )
601- if lookat is not None :
602- lookat = torch .as_tensor (lookat , dtype = gs .tc_float , device = gs .device )
603- if up is not None :
604- up = torch .as_tensor (up , dtype = gs .tc_float , device = gs .device )
605-
606- # Expand to n_envs
607597 n_envs = max (self ._visualizer .scene .n_envs , 1 )
608- if env_idx is None :
609- env_idx = torch .arange (n_envs )
610598 if transform is not None :
599+ transform = torch .as_tensor (transform , dtype = gs .tc_float , device = gs .device )
611600 if transform .shape [- 2 :] != (4 , 4 ):
612601 raise ValueError (f"Transform shape { transform .shape } does not match (4, 4)" )
613602 if transform .ndim == 2 :
614603 transform = transform .expand ((n_envs , 4 , 4 ))
604+ else :
605+ transform = self ._multi_env_transform_tensor
606+
615607 if pos is not None :
608+ pos = torch .as_tensor (pos , dtype = gs .tc_float , device = gs .device )
616609 assert pos .shape [- 1 ] == 3 , f"Pos shape { pos .shape } does not match (n_envs, 3)"
617610 if pos .ndim == 1 :
618611 pos = pos .expand ((n_envs , 3 ))
612+ else :
613+ pos = self ._multi_env_pos_tensor
614+
619615 if lookat is not None :
616+ lookat = torch .as_tensor (lookat , dtype = gs .tc_float , device = gs .device )
620617 assert lookat .shape [- 1 ] == 3 , f"Lookat shape { lookat .shape } does not match (n_envs, 3)"
621618 if lookat .ndim == 1 :
622619 lookat = lookat .expand ((n_envs , 3 ))
620+ else :
621+ lookat = self ._multi_env_lookat_tensor
622+
623623 if up is not None :
624+ up = torch .as_tensor (up , dtype = gs .tc_float , device = gs .device )
624625 assert up .shape [- 1 ] == 3 , f"Up shape { up .shape } does not match (n_envs, 3)"
625626 if up .ndim == 1 :
626627 up = up .expand ((n_envs , 3 ))
628+ else :
629+ up = self ._multi_env_up_tensor
630+
631+ if env_idx is None :
632+ env_idx = torch .arange (n_envs )
627633
628634 assert (
629635 transform is None or transform .shape [0 ] == env_idx .shape [0 ]
@@ -643,15 +649,15 @@ def set_pose(self, transform=None, pos=None, lookat=None, up=None, env_idx=None)
643649 new_lookat = self ._multi_env_lookat_tensor [env_idx ]
644650 new_up = self ._multi_env_up_tensor [env_idx ]
645651 if transform is not None :
646- new_transform = transform
652+ new_transform = transform . clone ()
647653 new_pos , new_lookat , new_up = gu .T_to_pos_lookat_up (new_transform )
648654 else :
649655 if pos is not None :
650- new_pos = pos
656+ new_pos = pos . clone ()
651657 if lookat is not None :
652- new_lookat = lookat
658+ new_lookat = lookat . clone ()
653659 if up is not None :
654- new_up = up
660+ new_up = up . clone ()
655661 new_transform = gu .pos_lookat_up_to_T (new_pos , new_lookat , new_up )
656662
657663 new_quat = _T_to_quat_for_madrona (new_transform )
0 commit comments