Skip to content

Commit 2b7b9bc

Browse files
committed
Leverage GsTaichi zero-copy in data accessors.
1 parent e7558e4 commit 2b7b9bc

File tree

8 files changed

+148
-79
lines changed

8 files changed

+148
-79
lines changed

genesis/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
backend: gs_backend | None = None
4242
use_ndarray: bool | None = None
4343
use_fastcache: bool | None = None
44+
use_zerocopy: bool | None = None
4445
EPS: float | None = None
4546

4647

@@ -117,7 +118,7 @@ def init(
117118
backend = gs_backend.cpu
118119

119120
# Configure GsTaichi fast cache and array type
120-
global use_ndarray, use_fastcache
121+
global use_ndarray, use_fastcache, use_zerocopy
121122
is_ndarray_disabled = (os.environ.get("GS_ENABLE_NDARRAY") or ("0" if sys.platform == "darwin" else "1")) == "0"
122123
if use_ndarray is None:
123124
_use_ndarray = not (is_ndarray_disabled or performance_mode)
@@ -136,6 +137,15 @@ def init(
136137
raise_exception("Genesis previously initialized. GsTaichi fast cache mode cannot be disabled anymore.")
137138
use_ndarray, use_fastcache = _use_ndarray, _use_fastcache
138139

140+
# Unlike dynamic vs static array mode, and fastcache, zero-copy can be toggled on/off between inits without issue
141+
_use_zerocopy = int(os.environ["GS_ENABLE_ZEROCOPY"]) if "GS_ENABLE_ZEROCOPY" in os.environ else None
142+
if use_ndarray and backend in (gs_backend.cpu, gs_backend.cuda):
143+
if _use_zerocopy is None:
144+
_use_zerocopy = True
145+
elif _use_zerocopy:
146+
raise_exception("Zero-copy is only supported by GsTaichi dynamic array mode on the CPU and CUDA backends.")
147+
use_zerocopy = _use_zerocopy
148+
139149
# Define the right dtypes in accordance with selected backend and precision
140150
global ti_float, np_float, tc_float
141151
if precision == "32":

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def _build(self):
560560

561561
self._n_qs = self.n_qs
562562
self._n_dofs = self.n_dofs
563+
self._n_geoms = self.n_geoms
563564
self._is_built = True
564565

565566
verts_start = 0
@@ -575,6 +576,8 @@ def _build(self):
575576
self._free_verts_idx_local = torch.cat(free_verts_idx_local)
576577
if fixed_verts_idx_local:
577578
self._fixed_verts_idx_local = torch.cat(fixed_verts_idx_local)
579+
self._n_free_verts = len(self._free_verts_idx_local)
580+
self._n_fixed_verts = len(self._fixed_verts_idx_local)
578581

579582
self._geoms = self.geoms
580583
self._vgeoms = self.vgeoms
@@ -2015,23 +2018,36 @@ def get_verts(self):
20152018
verts : torch.Tensor, shape (n_envs, n_verts, 3)
20162019
The vertices of the entity.
20172020
"""
2018-
self._solver.update_verts_for_geoms(range(self.geom_start, self.geom_end))
2021+
self._solver.update_verts_for_geoms(slice(self.geom_start, self.geom_end))
20192022

2020-
tensor = torch.empty((self._solver._B, self.n_verts, 3), dtype=gs.tc_float, device=gs.device)
2021-
has_fixed_verts, has_free_vertices = len(self._fixed_verts_idx_local) > 0, len(self._free_verts_idx_local) > 0
2022-
if has_fixed_verts:
2023-
_kernel_get_fixed_verts(
2024-
tensor, self._fixed_verts_idx_local, self._fixed_verts_state_start, self._solver.fixed_verts_state
2025-
)
2026-
if has_free_vertices:
2027-
# FIXME: Get around some bug in gstaichi when using gstaichi with metal backend
2028-
must_copy = gs.backend == gs.metal and has_fixed_verts
2029-
tensor_free = torch.zeros_like(tensor) if must_copy else tensor
2030-
_kernel_get_free_verts(
2031-
tensor_free, self._free_verts_idx_local, self._free_verts_state_start, self._solver.free_verts_state
2032-
)
2033-
if must_copy:
2034-
tensor += tensor_free
2023+
n_fixed_verts, n_free_vertices = self._n_fixed_verts, self._n_free_verts
2024+
tensor = torch.empty((self._solver._B, n_fixed_verts + n_free_vertices, 3), dtype=gs.tc_float, device=gs.device)
2025+
2026+
if n_fixed_verts > 0:
2027+
if gs.use_zerocopy:
2028+
fixed_verts_state = ti_to_torch(self._solver.fixed_verts_state.pos)
2029+
tensor[:, self._fixed_verts_idx_local] = fixed_verts_state[
2030+
self._fixed_verts_state_start : self._fixed_verts_state_start + n_fixed_verts
2031+
]
2032+
else:
2033+
_kernel_get_fixed_verts(
2034+
tensor, self._fixed_verts_idx_local, self._fixed_verts_state_start, self._solver.fixed_verts_state
2035+
)
2036+
if n_free_vertices > 0:
2037+
if gs.use_zerocopy:
2038+
free_verts_state = ti_to_torch(self._solver.free_verts_state.pos, transpose=True)
2039+
tensor[:, self._free_verts_idx_local] = free_verts_state[
2040+
:, self._free_verts_state_start : self._free_verts_state_start + n_free_vertices
2041+
]
2042+
else:
2043+
# FIXME: Get around some bug in gstaichi when using gstaichi with metal backend
2044+
must_copy = gs.backend == gs.metal and n_fixed_verts > 0
2045+
tensor_free = torch.zeros_like(tensor) if must_copy else tensor
2046+
_kernel_get_free_verts(
2047+
tensor_free, self._free_verts_idx_local, self._free_verts_state_start, self._solver.free_verts_state
2048+
)
2049+
if must_copy:
2050+
tensor += tensor_free
20352051

20362052
if self._solver.n_envs == 0:
20372053
tensor = tensor.squeeze(0)
@@ -2840,6 +2856,8 @@ def n_dofs(self):
28402856
@property
28412857
def n_geoms(self):
28422858
"""The number of `RigidGeom` in the entity."""
2859+
if self._is_built:
2860+
return self._n_geoms
28432861
return sum(link.n_geoms for link in self._links)
28442862

28452863
@property

genesis/engine/entities/rigid_entity/rigid_link.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def get_verts(self):
305305
"""
306306
Get the vertices of the link's collision body (concatenation of all `link.geoms`) in the world frame.
307307
"""
308-
self._solver.update_verts_for_geoms(range(self.geom_start, self.geom_end))
308+
self._solver.update_verts_for_geoms(slice(self.geom_start, self.geom_end))
309309

310310
if self.is_fixed and not self._entity._batch_fixed_verts:
311311
tensor = torch.empty((self.n_verts, 3), dtype=gs.tc_float, device=gs.device)

genesis/engine/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from .solvers.base_solver import Solver
4545

4646

47-
RATE_CHECK_ERRNO = 10
47+
RATE_CHECK_ERRNO = 10 if not gs.use_zerocopy else 1
4848

4949

5050
@ti.data_oriented

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ def substep(self):
894894
)
895895

896896
def check_errno(self):
897-
match kernel_get_errno(self._errno):
897+
match ti_to_torch(self._errno):
898898
case 1:
899899
max_collision_pairs_broad = self.collider._collider_info.max_collision_pairs_broad[None]
900900
gs.raise_exception(
@@ -1362,8 +1362,10 @@ def _sanitize_1D_io_variables(
13621362
_inputs_idx = torch.as_tensor(inputs_idx, dtype=gs.tc_int, device=gs.device).contiguous()
13631363
if _inputs_idx is not inputs_idx:
13641364
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
1365-
_inputs_idx = torch.atleast_1d(_inputs_idx)
1366-
if _inputs_idx.ndim != 1:
1365+
_inputs_ndim = _inputs_idx.ndim
1366+
if _inputs_ndim == 0:
1367+
_inputs_idx = _inputs_idx[None]
1368+
elif _inputs_ndim > 1:
13671369
gs.raise_exception(f"Expecting 1D tensor for `{idx_name}`.")
13681370
if not ((0 <= _inputs_idx).all() or (_inputs_idx < input_size).all()):
13691371
gs.raise_exception(f"`{idx_name}` is out-of-range.")
@@ -1372,19 +1374,23 @@ def _sanitize_1D_io_variables(
13721374
_tensor = torch.as_tensor(tensor, dtype=gs.tc_float, device=gs.device).contiguous()
13731375
if _tensor is not tensor:
13741376
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
1375-
tensor = _tensor.unsqueeze(0) if batched and self.n_envs and _tensor.ndim == 1 else _tensor
1376-
1377+
tensor_ndim = _tensor.ndim
1378+
if batched and self.n_envs and tensor_ndim == 1:
1379+
tensor = _tensor.unsqueeze(0)
1380+
tensor_ndim += 1
1381+
else:
1382+
tensor = _tensor
13771383
if tensor.shape[-1] != len(inputs_idx):
13781384
gs.raise_exception(f"Last dimension of the input tensor does not match length of `{idx_name}`.")
13791385

13801386
if batched:
13811387
if self.n_envs == 0:
1382-
if tensor.ndim != 1:
1388+
if tensor_ndim != 1:
13831389
gs.raise_exception(
13841390
f"Invalid input shape: {tensor.shape}. Expecting a 1D tensor for non-parallelized scene."
13851391
)
13861392
else:
1387-
if tensor.ndim == 2:
1393+
if tensor_ndim == 2:
13881394
if tensor.shape[0] != len(envs_idx):
13891395
gs.raise_exception(
13901396
f"Invalid input shape: {tensor.shape}. First dimension of the input tensor does not match "
@@ -1395,7 +1401,7 @@ def _sanitize_1D_io_variables(
13951401
f"Invalid input shape: {tensor.shape}. Expecting a 2D tensor for scene with parallelized envs."
13961402
)
13971403
else:
1398-
if tensor.ndim != 1:
1404+
if tensor_ndim != 1:
13991405
gs.raise_exception("Expecting 1D output tensor.")
14001406
return tensor, _inputs_idx, envs_idx
14011407

@@ -2382,6 +2388,11 @@ def set_drone_rpm(self, n_propellers, propellers_link_idxs, propellers_rpm, prop
23822388
)
23832389

23842390
def update_verts_for_geoms(self, geoms_idx):
2391+
if gs.use_zerocopy:
2392+
verts_updated = ti_to_torch(self.geoms_state.verts_updated, transpose=False)
2393+
if verts_updated[geoms_idx].all():
2394+
return
2395+
23852396
_, geoms_idx, _ = self._sanitize_1D_io_variables(
23862397
None, geoms_idx, self.n_geoms, None, idx_name="geoms_idx", skip_allocation=True, unsafe=False
23872398
)
@@ -6957,8 +6968,3 @@ def kernel_set_geoms_friction(
69576968
ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
69586969
for i_g_ in ti.ndrange(geoms_idx.shape[0]):
69596970
geoms_info.friction[geoms_idx[i_g_]] = friction[i_g_]
6960-
6961-
6962-
@ti.kernel(fastcache=gs.use_fastcache)
6963-
def kernel_get_errno(errno: array_class.V_ANNOTATION) -> ti.i32:
6964-
return errno[None]

genesis/utils/array_class.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import math
22
import dataclasses
3-
from functools import partial
3+
from functools import partial, wraps
44

55
import gstaichi as ti
66
import numpy as np
7+
import torch
78

89
import genesis as gs
910

@@ -12,12 +13,20 @@
1213
gs.raise_exception("Genesis hasn't been initialized. Did you call `gs.init()`?")
1314

1415

15-
V_ANNOTATION = ti.types.ndarray() if gs.use_ndarray else ti.template
16-
V = ti.ndarray if gs.use_ndarray else ti.field
17-
V_VEC = ti.Vector.ndarray if gs.use_ndarray else ti.Vector.field
18-
V_MAT = ti.Matrix.ndarray if gs.use_ndarray else ti.Matrix.field
16+
def build_tensor_type(tensor_type):
17+
@wraps(tensor_type)
18+
def _tensor_type_wrapper(*args, **kwargs):
19+
tensor = tensor_type(*args, **kwargs)
20+
try:
21+
# dlpack does not hold alive the original memory, so no need to track lifetime in tensor deleter
22+
tensor._tc = torch.utils.dlpack.from_dlpack(tensor.to_dlpack())
23+
except RuntimeError as e:
24+
raise RuntimeError(f"Zero-copy is not supported for backend '{gs.backend}'.") from e
25+
return tensor
1926

20-
DATA_ORIENTED = partial(dataclasses.dataclass, frozen=True) if gs.use_ndarray else ti.data_oriented
27+
if gs.use_zerocopy:
28+
return _tensor_type_wrapper
29+
return tensor_type
2130

2231

2332
def maybe_shape(shape, is_on):
@@ -59,6 +68,11 @@ def __init__(self, *args, **kwargs):
5968
return super().__new__(cls, name, bases, namespace)
6069

6170

71+
V_ANNOTATION = ti.types.ndarray() if gs.use_ndarray else ti.template
72+
V = build_tensor_type(ti.ndarray if gs.use_ndarray else ti.field)
73+
V_VEC = build_tensor_type(ti.Vector.ndarray if gs.use_ndarray else ti.Vector.field)
74+
V_MAT = build_tensor_type(ti.Matrix.ndarray if gs.use_ndarray else ti.Matrix.field)
75+
DATA_ORIENTED = partial(dataclasses.dataclass, frozen=True) if gs.use_ndarray else ti.data_oriented
6276
BASE_METACLASS = type if gs.use_ndarray else AutoInitMeta
6377

6478

genesis/utils/misc.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -581,39 +581,55 @@ def ti_to_python(
581581
# Get metadata
582582
ti_data_meta = _get_ti_metadata(value)
583583

584-
# Extract value as a whole.
585-
# Note that this is usually much faster than using a custom kernel to extract a slice.
586-
# The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
587-
is_metal = gs.device.type == "mps"
588-
out_dtype = _to_torch_type_fast(ti_data_meta.dtype) if to_torch else _to_numpy_type_fast(ti_data_meta.dtype)
589-
data_type = type(value)
590-
if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
591-
if to_torch:
592-
out = torch.zeros(ti_data_meta.shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
593-
else:
594-
out = np.zeros(ti_data_meta.shape, dtype=out_dtype)
595-
TO_EXT_ARR_FAST_MAP[data_type](value, out)
596-
elif issubclass(data_type, ti.MatrixField):
597-
as_vector = value.m == 1
598-
shape_ext = (value.n,) if as_vector else (value.n, value.m)
599-
if to_torch:
600-
out = torch.empty(ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
601-
else:
602-
out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
603-
TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
604-
elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
605-
layout_is_aos = 1
606-
as_vector = issubclass(data_type, ti.VectorNdarray)
607-
shape_ext = (value.n,) if as_vector else (value.n, value.m)
608-
if to_torch:
609-
out = torch.empty(ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device)
584+
use_zerocopy = gs.use_zerocopy
585+
if gs.use_zerocopy:
586+
# Leverage zero-copy if enabled
587+
try:
588+
out = value._tc
589+
if not to_torch:
590+
out = tensor_to_array(out)
591+
except AttributeError:
592+
gs.logger.debug("Zero-copy memory sharing not available for this tensor. Falling back to copy mode.")
593+
use_zerocopy = False
594+
595+
if not use_zerocopy:
596+
# Extract value as a whole.
597+
# Note that this is usually much faster than using a custom kernel to extract a slice.
598+
# The implementation is based on `taichi.lang.(ScalarField | MatrixField).to_torch`.
599+
is_metal = gs.device.type == "mps"
600+
out_dtype = _to_torch_type_fast(ti_data_meta.dtype) if to_torch else _to_numpy_type_fast(ti_data_meta.dtype)
601+
data_type = type(value)
602+
if issubclass(data_type, (ti.ScalarField, ti.ScalarNdarray)):
603+
if to_torch:
604+
out = torch.zeros(ti_data_meta.shape, dtype=out_dtype, device="cpu" if is_metal else gs.device)
605+
else:
606+
out = np.zeros(ti_data_meta.shape, dtype=out_dtype)
607+
TO_EXT_ARR_FAST_MAP[data_type](value, out)
608+
elif issubclass(data_type, ti.MatrixField):
609+
as_vector = value.m == 1
610+
shape_ext = (value.n,) if as_vector else (value.n, value.m)
611+
if to_torch:
612+
out = torch.empty(
613+
ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
614+
)
615+
else:
616+
out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
617+
TO_EXT_ARR_FAST_MAP[data_type](value, out, as_vector)
618+
elif issubclass(data_type, (ti.VectorNdarray, ti.MatrixNdarray)):
619+
layout_is_aos = 1
620+
as_vector = issubclass(data_type, ti.VectorNdarray)
621+
shape_ext = (value.n,) if as_vector else (value.n, value.m)
622+
if to_torch:
623+
out = torch.empty(
624+
ti_data_meta.shape + shape_ext, dtype=out_dtype, device="cpu" if is_metal else gs.device
625+
)
626+
else:
627+
out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
628+
TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
610629
else:
611-
out = np.zeros(ti_data_meta.shape + shape_ext, dtype=out_dtype)
612-
TO_EXT_ARR_FAST_MAP[ti.MatrixNdarray](value, out, layout_is_aos, as_vector)
613-
else:
614-
gs.raise_exception(f"Unsupported type '{type(value)}'.")
615-
if to_torch and is_metal:
616-
out = out.to(gs.device)
630+
gs.raise_exception(f"Unsupported type '{type(value)}'.")
631+
if to_torch and is_metal:
632+
out = out.to(gs.device)
617633

618634
# Transpose if necessary and requested.
619635
# Note that it is worth transposing here before slicing, as it preserve row-major memory alignment in case of
@@ -645,7 +661,7 @@ def extract_slice(
645661
"""
646662
# Make sure that the user-arguments are valid if requested
647663
if not unsafe:
648-
if value.ndim == 1 and col_mask is not None:
664+
if col_mask is not None and value.ndim == 1:
649665
gs.raise_exception("Cannot specify column mask for 1D tensor.")
650666
for i, mask in enumerate((row_mask, col_mask)):
651667
if mask is None or isinstance(mask, slice):
@@ -739,6 +755,8 @@ def ti_to_torch(
739755
unsafe (bool, optional): Whether to skip validity check of the masks.
740756
"""
741757
tensor = ti_to_python(value, transpose, to_torch=True)
758+
if row_mask is None and col_mask is None:
759+
return tensor
742760

743761
ti_data_meta = _get_ti_metadata(value)
744762
if len(ti_data_meta.shape) < 2:
@@ -771,6 +789,8 @@ def ti_to_numpy(
771789
unsafe (bool, optional): Whether to skip validity check of the masks.
772790
"""
773791
tensor = ti_to_python(value, transpose, to_torch=False)
792+
if row_mask is None and col_mask is None:
793+
return tensor
774794

775795
ti_data_meta = _get_ti_metadata(value)
776796
if len(ti_data_meta.shape) < 2:

0 commit comments

Comments
 (0)