Skip to content

Commit ea4c7dd

Browse files
authored
[MISC] Avoid code duplication and improve zero-copy performance. (Genesis-Embodied-AI#2023)
1 parent 2e9899d commit ea4c7dd

File tree

5 files changed

+111
-172
lines changed

5 files changed

+111
-172
lines changed

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from genesis.repr_base import RBC
1111
from genesis.utils import geom as gu
1212

13-
from genesis.utils.misc import DeprecationError
13+
from genesis.utils.misc import DeprecationError, tensor_to_array
1414

1515
from .rigid_geom import RigidGeom, RigidVisGeom, _kernel_get_free_verts, _kernel_get_fixed_verts
1616

@@ -262,7 +262,7 @@ def get_pos(self, envs_idx=None):
262262
envs_idx : int or array of int, optional
263263
The indices of the environments to get the position. If None, get the position of all environments. Default is None.
264264
"""
265-
return self._solver.get_links_pos([self._idx], envs_idx).squeeze(-2)
265+
return self._solver.get_links_pos(self._idx, envs_idx).squeeze(-2)
266266

267267
@gs.assert_built
268268
def get_quat(self, envs_idx=None):
@@ -274,7 +274,7 @@ def get_quat(self, envs_idx=None):
274274
envs_idx : int or array of int, optional
275275
The indices of the environments to get the quaternion. If None, get the quaternion of all environments. Default is None.
276276
"""
277-
return self._solver.get_links_quat([self._idx], envs_idx).squeeze(-2)
277+
return self._solver.get_links_quat(self._idx, envs_idx).squeeze(-2)
278278

279279
@gs.assert_built
280280
def get_vel(self, envs_idx=None) -> torch.Tensor:
@@ -286,7 +286,7 @@ def get_vel(self, envs_idx=None) -> torch.Tensor:
286286
envs_idx : int or array of int, optional
287287
The indices of the environments to get the linear velocity. If None, get the linear velocity of all environments. Default is None.
288288
"""
289-
return self._solver.get_links_vel([self._idx], envs_idx).squeeze(-2)
289+
return self._solver.get_links_vel(self._idx, envs_idx).squeeze(-2)
290290

291291
@gs.assert_built
292292
def get_ang(self, envs_idx=None) -> torch.Tensor:
@@ -298,7 +298,7 @@ def get_ang(self, envs_idx=None) -> torch.Tensor:
298298
envs_idx : int or array of int, optional
299299
The indices of the environments to get the angular velocity. If None, get the angular velocity of all environments. Default is None.
300300
"""
301-
return self._solver.get_links_ang([self._idx], envs_idx).squeeze(-2)
301+
return self._solver.get_links_ang(self._idx, envs_idx).squeeze(-2)
302302

303303
@gs.assert_built
304304
def get_verts(self):
@@ -545,7 +545,7 @@ def invweight(self):
545545
The invweight of the link.
546546
"""
547547
if self._invweight is None:
548-
self._invweight = self._solver.get_links_invweight([self._idx]).cpu().numpy()[..., 0, :]
548+
self._invweight = tensor_to_array(self._solver.get_links_invweight(self._idx))[..., 0, :]
549549
return self._invweight
550550

551551
@property

genesis/engine/sensors/raycaster.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,7 @@ def build(self):
405405
self._shared_metadata.total_n_rays += num_rays
406406

407407
self._shared_metadata.points_to_sensor_idx = concat_with_tensor(
408-
self._shared_metadata.points_to_sensor_idx,
409-
[self._idx] * num_rays,
410-
flatten=True,
408+
self._shared_metadata.points_to_sensor_idx, [self._idx] * num_rays, flatten=True
411409
)
412410
self._shared_metadata.return_world_frame = concat_with_tensor(
413411
self._shared_metadata.return_world_frame, self._options.return_world_frame

genesis/engine/solvers/base_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def set_gravity(self, gravity, envs_idx=None, *, unsafe=False):
5757
_kernel_set_gravity_ndarray(gravity, envs_idx, self._gravity)
5858

5959
def get_gravity(self, envs_idx=None, *, unsafe=False):
60-
tensor = ti_to_torch(self._gravity, envs_idx, transpose=True, unsafe=unsafe)
60+
tensor = ti_to_torch(self._gravity, envs_idx, transpose=True)
6161
return tensor.squeeze(0) if self.n_envs == 0 else tensor
6262

6363
def dump_ckpt_to_numpy(self) -> dict[str, np.ndarray]:

0 commit comments

Comments
 (0)