Skip to content

Commit 9249c15

Browse files
authored
[BUG FIX] Various minor cleanup and bug fixes. (Genesis-Embodied-AI#1390)
* Cleanup pre-commit. * Cleanup rigid solver. * Cleanup 'Emitter.emit_omit' implementation. * Fix some wrong dtypes. * Fix PBD solver batching. * Fix 'add_weld_constraint'.
1 parent cce7fdd commit 9249c15

File tree

14 files changed

+116
-117
lines changed

14 files changed

+116
-117
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 25.1.0 # Use the latest version or specify the version you prefer
3+
rev: 25.1.0
44
hooks:
55
- id: black
6-
args:
7-
- --line-length=120
8-
- .

examples/rigid/suction_cup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def main():
9696

9797
# add suction / weld constraint
9898
rigid = scene.sim.rigid_solver
99-
link_cube = np.array([cube.get_link("box_baselink").idx], dtype=gs.np_int)
100-
link_franka = np.array([franka.get_link("hand").idx], dtype=gs.np_int)
99+
link_cube = cube.get_link("box_baselink").idx
100+
link_franka = franka.get_link("hand").idx
101101
rigid.add_weld_constraint(link_cube, link_franka)
102102

103103
# lift

genesis/engine/entities/emitter.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def emit(
150150
positions = gu.transform_by_trans_R(
151151
positions,
152152
pos,
153-
gu.z_up_to_R(direction) @ gu.axis_angle_to_R(np.array([0, 0, 1]), theta),
153+
gu.z_up_to_R(direction) @ gu.axis_angle_to_R(np.array([0.0, 0.0, 1.0], dtype=gs.np_float), theta),
154154
).astype(gs.np_float)
155155

156156
positions = np.tile(positions[np.newaxis], (self._sim._B, 1, 1))
@@ -206,17 +206,21 @@ def emit_omni(self, source_radius=0.1, pos=(0.5, 0.5, 1.0), speed=1.0, particle_
206206
Parameters:
207207
----------
208208
source_radius: float, optional
209-
The radius of the sphere source. Particles will be emitted from a shell with inner radius using 0.8 * source_radius and outer radius using source_radius.
209+
The radius of the sphere source. Particles will be emitted from a shell with inner radius using
210+
'0.8 * source_radius' and outer radius using source_radius.
210211
pos: array_like, shape=(3,)
211212
The center of the sphere source.
212213
speed: float
213214
The speed of the emitted particles.
214215
particle_size: float | None
215-
The size (diameter) of the emitted particles. The actual number of particles emitted is determined by the volume of the sphere source and the size of the particles. If None, the solver's particle size is used. Note that this particle size only affects computation for number of particles emitted, not the actual size of the particles in simulation and rendering.
216+
The size (diameter) of the emitted particles. The actual number of particles emitted is determined by the
217+
volume of the sphere source and the size of the particles. If None, the solver's particle size is used.
218+
Note that this particle size only affects computation for number of particles emitted, not the actual size
219+
of the particles in simulation and rendering.
216220
"""
217221
assert self._entity is not None
218222

219-
pos = np.array(pos)
223+
pos = np.asarray(pos, dtype=gs.np_float)
220224

221225
if particle_size is None:
222226
particle_size = self._solver.particle_size
@@ -227,20 +231,20 @@ def emit_omni(self, source_radius=0.1, pos=(0.5, 0.5, 1.0), speed=1.0, particle_
227231
inner_radius=source_radius * 0.4,
228232
sampler=self._entity.sampler,
229233
)
230-
231-
positions = gu.transform_by_T(positions_, gu.trans_to_T(pos)).astype(gs.np_float)
234+
positions = pos + positions_
232235

233236
if not self._solver.boundary.is_inside(positions):
234237
gs.raise_exception("Emitted particles are outside the boundary.")
235238

236-
n_particles = len(positions)
237-
dists = np.linalg.norm(positions_, axis=1, keepdims=True)
238-
positions[np.where(dists < gs.EPS)[0]] = np.array([gs.EPS, gs.EPS, gs.EPS])
239-
vels = (positions_ / dists * speed).astype(gs.np_float)
239+
dists = np.linalg.norm(positions_, axis=1)
240+
positions[dists < gs.EPS] = gs.EPS
241+
vels = (speed / (dists + gs.EPS)) * positions_
240242

243+
n_particles = len(positions)
241244
if n_particles > self._entity.n_particles:
242245
gs.logger.warning(
243-
f"Number of particles to emit ({n_particles}) at the current step is larger than the maximum number of particles ({self._entity.n_particles})."
246+
f"Number of particles to emit ({n_particles}) at the current step is larger than the maximum number "
247+
f"of particles ({self._entity.n_particles})."
244248
)
245249

246250
self._solver._kernel_set_particles_pos(

genesis/engine/entities/mpm_entity.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,9 @@ def set_pos(self, f, pos):
8989
f : int
9090
The current substep index.
9191
pos : gs.Tensor
92-
A tensor of shape (n_particles, 3) representing particle positions.
92+
A tensor of shape (n_envs, n_particles, 3) representing particle positions.
9393
"""
94-
self.solver._kernel_set_particles_pos(
95-
f,
96-
self._particle_start,
97-
self._n_particles,
98-
pos,
99-
)
94+
self.solver._kernel_set_particles_pos(f, self._particle_start, self._n_particles, pos)
10095

10196
def set_pos_grad(self, f, pos_grad):
10297
"""

genesis/engine/entities/sph_entity.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,7 @@ def set_pos(self, f, pos):
9292
pos : ndarray
9393
Array of particle positions of shape (n_envs, n_particles, 3).
9494
"""
95-
self.solver._kernel_set_particles_pos(
96-
f,
97-
self._particle_start,
98-
self._n_particles,
99-
pos,
100-
)
95+
self.solver._kernel_set_particles_pos(f, self._particle_start, self._n_particles, pos)
10196

10297
def set_pos_grad(self, f: ti.i32, pos_grad: ti.types.ndarray()):
10398
"""

genesis/engine/solvers/mpm_solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,9 +720,9 @@ def _kernel_set_particles_pos(
720720
self.particles[f, i_global, i_b].pos[k] = pos[i_b, i_p, k]
721721

722722
# we restore these whenever directly setting positions
723-
self.particles[f, i_global, i_b].vel = ti.Vector.zero(gs.ti_float, 3)
723+
self.particles[f, i_global, i_b].vel.fill(0.0)
724724
self.particles[f, i_global, i_b].F = ti.Matrix.identity(gs.ti_float, 3)
725-
self.particles[f, i_global, i_b].C = ti.Matrix.zero(gs.ti_float, 3, 3)
725+
self.particles[f, i_global, i_b].C.fill(0.0)
726726
self.particles[f, i_global, i_b].Jp = self.particles_info[i_global].default_Jp
727727

728728
@ti.kernel

genesis/engine/solvers/pbd_solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,8 +840,8 @@ def _kernel_set_particles_pos(
840840
for i_p, i_b in ti.ndrange(n_particles, self._B):
841841
i_global = i_p + particle_start
842842
for k in ti.static(range(3)):
843-
self.particles[i_global, i_b].pos[k] = pos[i_p, k]
844-
self.particles[i_global, i_b].vel = ti.Vector.zero(gs.ti_float, 3)
843+
self.particles[i_global, i_b].pos[k] = pos[i_b, i_p, k]
844+
self.particles[i_global, i_b].vel.fill(0.0)
845845

846846
@ti.kernel
847847
def _kernel_set_particles_vel(

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -613,12 +613,12 @@ def _func_contact_mpr_terrain(self, i_ga, i_gb, i_b):
613613
r_max = ti.min(self._solver.terrain_rc[0] - 1, r_max)
614614
c_max = ti.min(self._solver.terrain_rc[1] - 1, c_max)
615615

616-
cnt = 0
616+
n_con = 0
617617
for r in range(r_min, r_max):
618618
nvert = 0
619619
for c in range(c_min, c_max + 1):
620620
for i in range(2):
621-
if cnt < self._n_contacts_per_pair:
621+
if n_con < ti.static(self._n_contacts_per_pair):
622622
nvert = nvert + 1
623623
self.add_prism_vert(
624624
sh * (r + i) + self._solver.terrain_xyz_maxmin[3],
@@ -651,17 +651,17 @@ def _func_contact_mpr_terrain(self, i_ga, i_gb, i_b):
651651
contact_pos = contact_pos + gb_pos
652652

653653
valid = True
654-
i_col = self.n_contacts[i_b]
655-
for j in range(cnt):
654+
i_c = self.n_contacts[i_b]
655+
for j in range(n_con):
656656
if (
657-
contact_pos - self.contact_data[i_col - j - 1, i_b].pos
657+
contact_pos - self.contact_data[i_c - j - 1, i_b].pos
658658
).norm() < tolerance:
659659
valid = False
660660
break
661661

662662
if valid:
663663
self._func_add_contact(i_ga, i_gb, normal, contact_pos, penetration, i_b)
664-
cnt = cnt + 1
664+
n_con = n_con + 1
665665

666666
self._solver.geoms_state[i_ga, i_b].pos, self._solver.geoms_state[i_ga, i_b].quat = ga_pos, ga_quat
667667
self._solver.geoms_state[i_gb, i_b].pos, self._solver.geoms_state[i_gb, i_b].quat = gb_pos, gb_quat
@@ -787,12 +787,13 @@ def _func_broad_phase(self):
787787
self.contact_cache[i_ga, i_gb, i_b].normal.fill(0.0)
788788
continue
789789

790-
if self.n_broad_pairs[i_b] == self._max_collision_pairs:
790+
i_p = self.n_broad_pairs[i_b]
791+
if i_p == self._max_collision_pairs:
791792
# print(self._warn_msg_max_collision_pairs)
792793
break
793-
self.broad_collision_pairs[self.n_broad_pairs[i_b], i_b][0] = i_ga
794-
self.broad_collision_pairs[self.n_broad_pairs[i_b], i_b][1] = i_gb
795-
self.n_broad_pairs[i_b] = self.n_broad_pairs[i_b] + 1
794+
self.broad_collision_pairs[i_p, i_b][0] = i_ga
795+
self.broad_collision_pairs[i_p, i_b][1] = i_gb
796+
self.n_broad_pairs[i_b] += 1
796797

797798
self.active_buffer[n_active, i_b] = self.sort_buffer[i, i_b].i_g
798799
n_active = n_active + 1
@@ -1058,9 +1059,9 @@ def _func_narrow_phase_nonconvex_vs_nonterrain(self):
10581059

10591060
# Discard contact point is repeated
10601061
repeated = False
1061-
for i_con in range(n_con):
1062+
for i_c in range(n_con):
10621063
if not repeated:
1063-
idx_prev = self.n_contacts[i_b] - 1 - i_con
1064+
idx_prev = self.n_contacts[i_b] - 1 - i_c
10641065
prev_contact = self.contact_data[idx_prev, i_b].pos
10651066
if (contact_pos - prev_contact).norm() < tolerance:
10661067
repeated = True
@@ -1106,7 +1107,7 @@ def _func_plane_box_contact(self, i_ga, i_gb, i_b):
11061107
contact_pos_0 = contact_pos
11071108
tolerance = self._func_compute_tolerance(i_ga, i_gb, i_b)
11081109
for i_v in range(gb_info.vert_start, gb_info.vert_end):
1109-
if n_con < self._n_contacts_per_pair:
1110+
if n_con < ti.static(self._n_contacts_per_pair):
11101111
pos_corner = gu.ti_transform_by_trans_quat(
11111112
self._solver.verts_info[i_v].init_pos, gb_state.pos, gb_state.quat
11121113
)
@@ -1119,10 +1120,8 @@ def _func_plane_box_contact(self, i_ga, i_gb, i_b):
11191120

11201121
@ti.func
11211122
def _func_add_contact(self, i_ga, i_gb, normal, contact_pos, penetration, i_b):
1122-
# print(f"Adding contact {i_ga} {i_gb}, normal:", normal, "contact_pos:", contact_pos, "penetration:", penetration)
1123-
i_col = self.n_contacts[i_b]
1124-
1125-
if i_col == self._max_contact_pairs:
1123+
i_c = self.n_contacts[i_b]
1124+
if i_c == self._max_contact_pairs:
11261125
# FIXME: 'ti.static_print' cannot be used as it will be printed systematically, completely ignoring guard
11271126
# condition, while 'print' is slowing down the kernel even if every called in practice...
11281127
# print(self._warn_msg_max_collision_pairs)
@@ -1135,17 +1134,17 @@ def _func_add_contact(self, i_ga, i_gb, normal, contact_pos, penetration, i_b):
11351134
friction_b = gb_info.friction * self._solver.geoms_state[i_gb, i_b].friction_ratio
11361135

11371136
# b to a
1138-
self.contact_data[i_col, i_b].geom_a = i_ga
1139-
self.contact_data[i_col, i_b].geom_b = i_gb
1140-
self.contact_data[i_col, i_b].normal = normal
1141-
self.contact_data[i_col, i_b].pos = contact_pos
1142-
self.contact_data[i_col, i_b].penetration = penetration
1143-
self.contact_data[i_col, i_b].friction = ti.max(ti.max(friction_a, friction_b), 1e-2)
1144-
self.contact_data[i_col, i_b].sol_params = 0.5 * (ga_info.sol_params + gb_info.sol_params)
1145-
self.contact_data[i_col, i_b].link_a = ga_info.link_idx
1146-
self.contact_data[i_col, i_b].link_b = gb_info.link_idx
1147-
1148-
self.n_contacts[i_b] = i_col + 1
1137+
self.contact_data[i_c, i_b].geom_a = i_ga
1138+
self.contact_data[i_c, i_b].geom_b = i_gb
1139+
self.contact_data[i_c, i_b].normal = normal
1140+
self.contact_data[i_c, i_b].pos = contact_pos
1141+
self.contact_data[i_c, i_b].penetration = penetration
1142+
self.contact_data[i_c, i_b].friction = ti.max(ti.max(friction_a, friction_b), 1e-2)
1143+
self.contact_data[i_c, i_b].sol_params = 0.5 * (ga_info.sol_params + gb_info.sol_params)
1144+
self.contact_data[i_c, i_b].link_a = ga_info.link_idx
1145+
self.contact_data[i_c, i_b].link_b = gb_info.link_idx
1146+
1147+
self.n_contacts[i_b] = i_c + 1
11491148

11501149
@ti.func
11511150
def _func_compute_tolerance(self, i_ga, i_gb, i_b):
@@ -1325,7 +1324,7 @@ def _func_convex_convex_contact(self, i_ga, i_gb, i_b):
13251324
# add the discovered contact points and stop multi-contact search.
13261325
for i_c in range(n_contacts):
13271326
# Ignore contact points if the number of contacts exceeds the limit.
1328-
if i_c < self._n_contacts_per_pair:
1327+
if i_c < ti.static(self._n_contacts_per_pair):
13291328
contact_pos = self._gjk.contact_pos[i_b, i_c]
13301329
normal = self._gjk.normal[i_b, i_c]
13311330
self._func_add_contact(i_ga, i_gb, normal, contact_pos, penetration, i_b)
@@ -1453,9 +1452,9 @@ def _func_convex_convex_contact(self, i_ga, i_gb, i_b):
14531452

14541453
# Discard contact point is repeated
14551454
repeated = False
1456-
for i_con in range(n_con):
1455+
for i_c in range(n_con):
14571456
if not repeated:
1458-
idx_prev = self.n_contacts[i_b] - 1 - i_con
1457+
idx_prev = self.n_contacts[i_b] - 1 - i_c
14591458
prev_contact = self.contact_data[idx_prev, i_b].pos
14601459
if (contact_pos - prev_contact).norm() < tolerance:
14611460
repeated = True

0 commit comments

Comments
 (0)