Skip to content

Commit 663fa80

Browse files
authored
Merge branch 'main' into feature/env-wise-gravity
2 parents 88fe5fe + db81a21 commit 663fa80

28 files changed

+875
-196
lines changed

.github/workflows/linux-gpu.yml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ jobs:
3535
3636
mkdir -p "${HOME}/.cache"
3737
38+
# Prefer idle nodes if any
39+
IDLE_NODES=$(sinfo -h -o "%N %t" | awk '$2 == "idle" {print $1}')
40+
if [[ -n "$IDLE_NODES" ]]; then
41+
NODELIST="--nodelist=$IDLE_NODES"
42+
fi
43+
3844
srun \
3945
--container-image="/mnt/data/images/genesis-v${GENESIS_IMAGE_VER}.sqsh" \
4046
--container-mounts=\
@@ -44,7 +50,7 @@ jobs:
4450
--export=\
4551
HF_TOKEN="${HF_TOKEN}",\
4652
NVIDIA_DRIVER_CAPABILITIES=all \
47-
--partition=hpc-mid --nodes=1 --gpus=1 --time="${TIMEOUT_MINUTES}" \
53+
--partition=hpc-mid ${NODELIST} --nodes=1 --time="${TIMEOUT_MINUTES}" \
4854
--job-name=${SLURM_JOB_NAME} \
4955
bash -c "
5056
pip install -e '.[dev,render]' && \
@@ -69,16 +75,16 @@ jobs:
6975
"${{ github.workspace }}":/root/workspace \
7076
--no-container-mount-home --container-workdir=/root/workspace \
7177
--export=${SLURM_ENV_VARS} \
72-
--partition=hpc-mid --exclusive --nodes=1 --gpus=1 --time="${TIMEOUT_MINUTES}" \
78+
--partition=hpc-mid --exclusive --nodes=1 --time="${TIMEOUT_MINUTES}" \
7379
--job-name=${SLURM_JOB_NAME} \
7480
bash -c "
7581
: # sudo apt install -y tmate && \
7682
tmate -S /tmp/tmate.sock new-session -d && \
7783
tmate -S /tmp/tmate.sock wait tmate-ready && \
7884
tmate -S /tmp/tmate.sock display -p '#{tmate_ssh}'
7985
pip install -e '.[dev,render]' && \
80-
pytest --print -x -m 'benchmarks' --backend gpu ./tests && \
81-
cp 'speed_test.txt' '/mnt/data/artifacts/speed_test_${SLURM_JOB_NAME}.txt'
86+
pytest --print -x -m 'benchmarks' ./tests && \
87+
cat speed_test*.txt > '/mnt/data/artifacts/speed_test_${SLURM_JOB_NAME}.txt'
8288
: # tmate -S /tmp/tmate.sock wait tmate-exit
8389
"
8490

examples/drone/hover_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def step(self, actions):
134134
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
135135
# update target pos
136136
if self.target is not None:
137-
self.target.set_pos(self.commands, zero_velocity=True, envs_idx=list(range(self.num_envs)))
137+
self.target.set_pos(self.commands, zero_velocity=True)
138138
self.scene.step()
139139

140140
# update buffers

genesis/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def _display_greeting(INFO_length):
291291
wave_width = max(0, min(38, wave_width))
292292
bar_width = wave_width * 2 + 9
293293
wave = ("┈┉" * wave_width)[:wave_width]
294+
global logger
294295
logger.info(f"~<╭{'─'*(bar_width)}╮>~")
295296
logger.info(f"~<│{wave}>~ ~~~~<Genesis>~~~~ ~<{wave}│>~")
296297
logger.info(f"~<╰{'─'*(bar_width)}╯>~")
@@ -314,9 +315,10 @@ def _custom_excepthook(exctype, value, tb):
314315
print("".join(traceback.format_exception(exctype, value, tb)))
315316

316317
# Logger the exception right before exit if possible
318+
global logger
317319
try:
318320
logger.error(f"{exctype.__name__}: {value}")
319-
except AttributeError:
321+
except (AttributeError, NameError):
320322
# Logger may not be configured at this point
321323
pass
322324

genesis/engine/entities/hybrid_entity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,8 @@ def _kernel_update_soft_part_mpm(self, f: ti.i32):
431431
acc = vel_d / dt_for_rigid_acc
432432
frc_vel = mass_real * acc
433433
frc_ang = (x_pos - link.COM).cross(frc_vel)
434-
self._solver_rigid.links_state[link_idx, i_b].cfrc_ext_vel += frc_vel
435-
self._solver_rigid.links_state[link_idx, i_b].cfrc_ext_ang += frc_ang
434+
self._solver_rigid.links_state[link_idx, i_b].cfrc_applied_vel += frc_vel
435+
self._solver_rigid.links_state[link_idx, i_b].cfrc_applied_ang += frc_ang
436436

437437
# rigid-to-soft coupling # NOTE: this may lead to unstable feedback loop
438438
self._solver_soft.particles[f_, i_global, i_b].vel += vel_d * self.material.soft_dv_coef

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2748,7 +2748,7 @@ def get_contacts(self, with_entity=None, exclude_self_contact=False):
27482748
if self._solver.n_envs == 0:
27492749
contacts_info = {key: value[valid_mask] for key, value in contacts_info.items()}
27502750
else:
2751-
contacts_info = {key: value[:, valid_mask] for key, value in contacts_info.items()}
2751+
contacts_info["valid_mask"] = valid_mask
27522752

27532753
contacts_info["force_a"] = -contacts_info["force"]
27542754
contacts_info["force_b"] = +contacts_info["force"]

genesis/engine/scene.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import numpy as np
44
import torch
5+
import pickle
6+
import time
7+
import taichi as ti
58

69
import genesis as gs
710
import genesis.utils.geom as gu
@@ -1057,6 +1060,69 @@ def _backward(self):
10571060
self._backward_ready = False
10581061
self._forward_ready = False
10591062

1063+
def dump_ckpt_to_numpy(self) -> dict[str, np.ndarray]:
1064+
"""
1065+
Collect every Taichi field in the **scene and its active solvers** and
1066+
return them as a flat ``{key: ndarray}`` dictionary.
1067+
1068+
Returns
1069+
-------
1070+
dict[str, np.ndarray]
1071+
Mapping ``"Class.attr[.member]" → array`` with raw field data.
1072+
"""
1073+
arrays: dict[str, np.ndarray] = {}
1074+
1075+
for name, field in self.__dict__.items():
1076+
if isinstance(field, ti.Field):
1077+
arrays[".".join((self.__class__.__name__, name))] = field.to_numpy()
1078+
1079+
for solver in self.active_solvers:
1080+
arrays.update(solver.dump_ckpt_to_numpy())
1081+
1082+
return arrays
1083+
1084+
def save_checkpoint(self, path: str | os.PathLike) -> None:
1085+
"""
1086+
Pickle the full physics state to *one* file.
1087+
1088+
Parameters
1089+
----------
1090+
path : str | os.PathLike
1091+
Destination filename.
1092+
"""
1093+
state = {
1094+
"timestamp": time.time(),
1095+
"step_index": self.t,
1096+
"arrays": self.dump_ckpt_to_numpy(),
1097+
}
1098+
with open(path, "wb") as f:
1099+
pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
1100+
1101+
def load_checkpoint(self, path: str | os.PathLike) -> None:
1102+
"""
1103+
Restore a file produced by :py:meth:`save_checkpoint`.
1104+
1105+
Parameters
1106+
----------
1107+
path : str | os.PathLike
1108+
Path to the checkpoint pickle.
1109+
"""
1110+
with open(path, "rb") as f:
1111+
state = pickle.load(f)
1112+
1113+
arrays = state["arrays"]
1114+
1115+
for name, field in self.__dict__.items():
1116+
if isinstance(field, ti.Field):
1117+
key = ".".join((self.__class__.__name__, name))
1118+
if key in arrays:
1119+
field.from_numpy(arrays[key])
1120+
1121+
for solver in self.active_solvers:
1122+
solver.load_ckpt_from_numpy(arrays)
1123+
1124+
self._t = state.get("step_index", self._t)
1125+
10601126
# ------------------------------------------------------------------------------------
10611127
# ----------------------------------- properties -------------------------------------
10621128
# ------------------------------------------------------------------------------------

genesis/engine/solvers/base_solver.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import TYPE_CHECKING
22
import numpy as np
33
import taichi as ti
4+
import torch
5+
from genesis.utils.misc import ti_field_to_torch
46

57
import genesis as gs
68
from genesis.engine.entities.base_entity import Entity
@@ -28,6 +30,53 @@ def __init__(self, scene: "Scene", sim: "Simulator", options):
2830
def _add_force_field(self, force_field):
2931
self._ffs.append(force_field)
3032

33+
def dump_ckpt_to_numpy(self) -> dict[str, np.ndarray]:
34+
arrays: dict[str, np.ndarray] = {}
35+
36+
for attr_name, field in self.__dict__.items():
37+
if not isinstance(field, ti.Field):
38+
continue
39+
40+
key_base = ".".join((self.__class__.__name__, attr_name))
41+
data = field.to_numpy()
42+
43+
# StructField → data is a dict: flatten each member
44+
if isinstance(data, dict):
45+
for sub_name, sub_arr in data.items():
46+
arrays[f"{key_base}.{sub_name}"] = (
47+
sub_arr if isinstance(sub_arr, np.ndarray) else np.asarray(sub_arr)
48+
)
49+
else:
50+
arrays[key_base] = data if isinstance(data, np.ndarray) else np.asarray(data)
51+
52+
return arrays
53+
54+
def load_ckpt_from_numpy(self, arr_dict: dict[str, np.ndarray]) -> None:
55+
for attr_name, field in self.__dict__.items():
56+
if not isinstance(field, ti.Field):
57+
continue
58+
59+
key_base = ".".join((self.__class__.__name__, attr_name))
60+
member_prefix = key_base + "."
61+
62+
# ---- StructField: gather its members -----------------------------
63+
member_items = {}
64+
for saved_key, saved_arr in arr_dict.items():
65+
if saved_key.startswith(member_prefix):
66+
sub_name = saved_key[len(member_prefix) :]
67+
member_items[sub_name] = saved_arr
68+
69+
if member_items: # we found at least one sub-member
70+
field.from_numpy(member_items)
71+
continue
72+
73+
# ---- Ordinary field ---------------------------------------------
74+
if key_base not in arr_dict:
75+
continue # nothing saved for this attribute
76+
77+
arr = arr_dict[key_base]
78+
field.from_numpy(arr)
79+
3180
# ------------------------------------------------------------------------------------
3281
# ----------------------------------- properties -------------------------------------
3382
# ------------------------------------------------------------------------------------

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,19 +2088,20 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True):
20882088
# Allocate output buffer
20892089
if to_torch:
20902090
iout = torch.full((out_size, 4), -1, dtype=gs.tc_int, device=gs.device)
2091-
fout = torch.empty((out_size, 10), dtype=gs.tc_float, device=gs.device)
2091+
fout = torch.zeros((out_size, 10), dtype=gs.tc_float, device=gs.device)
20922092
else:
20932093
iout = np.full((out_size, 4), -1, dtype=gs.np_int)
2094-
fout = np.empty((out_size, 10), dtype=gs.np_float)
2094+
fout = np.zeros((out_size, 10), dtype=gs.np_float)
20952095

20962096
# Copy contact data
2097-
self._kernel_get_contacts(as_tensor, iout, fout)
2097+
if n_contacts_max > 0:
2098+
self._kernel_get_contacts(as_tensor, iout, fout)
20982099

2099-
# Return structured view (no copy)
2100+
# Build structured view (no copy)
21002101
if as_tensor:
21012102
if self._solver.n_envs > 0:
2102-
iout = iout.reshape((n_contacts_max, n_envs, -1))
2103-
fout = fout.reshape((n_contacts_max, n_envs, -1))
2103+
iout = iout.reshape((n_envs, n_contacts_max, 4))
2104+
fout = fout.reshape((n_envs, n_contacts_max, 10))
21042105
iout_chunks = (iout[..., 0], iout[..., 1], iout[..., 2], iout[..., 3])
21052106
fout_chunks = (fout[..., 0], fout[..., 1:4], fout[..., 4:7], fout[..., 7:])
21062107
values = (*iout_chunks, *fout_chunks)
@@ -2135,10 +2136,7 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True):
21352136
values = (*iout_chunks, *fout_chunks)
21362137

21372138
contacts_info = dict(
2138-
zip(
2139-
("link_a", "link_b", "geom_a", "geom_b", "penetration", "position", "normal", "force"),
2140-
(value.swapaxes(0, 1) for value in values) if as_tensor and self._solver.n_envs > 0 else values,
2141-
)
2139+
zip(("link_a", "link_b", "geom_a", "geom_b", "penetration", "position", "normal", "force"), values)
21422140
)
21432141

21442142
# Cache contact information before returning
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from dataclasses import dataclass
2+
3+
from .vec3 import Vec3
4+
5+
6+
EPSILON = 1e-6
7+
8+
9+
class Ray:
10+
origin: Vec3
11+
direction: Vec3
12+
13+
def __init__(self, origin: Vec3, direction: Vec3):
14+
self.origin = origin
15+
self.direction = direction.normalized()
16+
17+
def __repr__(self) -> str:
18+
return f"Ray(origin={self.origin}, direction={self.direction})"
19+
20+
21+
@dataclass
22+
class RayHit:
23+
is_hit: bool
24+
distance: float
25+
normal: Vec3
26+
position: Vec3
27+
object_idx: int
28+
29+
30+
class Plane:
31+
normal: Vec3
32+
distance: float # distance from plane to origin along normal
33+
34+
def __init__(self, normal: Vec3, point: Vec3):
35+
self.normal = normal
36+
self.distance = -normal.dot(point)
37+
38+
def raycast(self, ray: Ray) -> RayHit:
39+
dot = ray.direction.dot(self.normal)
40+
dist = ray.origin.dot(self.normal) + self.distance
41+
42+
if -EPSILON < dot or dist < EPSILON:
43+
return RayHit(is_hit=False, distance=0, normal=Vec3.zero(), position=Vec3.zero(), object_idx=-1)
44+
45+
dist_along_ray = dist / -dot
46+
47+
return RayHit(
48+
is_hit=True,
49+
distance=dist_along_ray,
50+
normal=self.normal,
51+
position=ray.origin + ray.direction * dist_along_ray,
52+
object_idx=0
53+
)
54+

0 commit comments

Comments
 (0)