Skip to content

Commit d1e64cf

Browse files
authored
Merge branch 'main' into yiling/250713_aux_kernels
2 parents f2ce039 + 47bf4be commit d1e64cf

File tree

5 files changed

+1857
-963
lines changed

5 files changed

+1857
-963
lines changed

genesis/engine/solvers/rigid/array_class.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,165 @@ def __init__(self, f_batch):
224224
self.simplex_size = ti.field(gs.ti_int, shape=f_batch())
225225

226226

227+
# =========================================== GJK ===========================================
228+
@ti.data_oriented
229+
class GJKState:
230+
def __init__(self, solver, static_rigid_sim_config, gjk_static_config):
231+
_B = solver._B
232+
polytope_max_faces = gjk_static_config.polytope_max_faces
233+
max_contacts_per_pair = gjk_static_config.max_contacts_per_pair
234+
max_contact_polygon_verts = gjk_static_config.max_contact_polygon_verts
235+
236+
# Cache to store the previous support points for support mesh function.
237+
self.support_mesh_prev_vertex_id = ti.field(dtype=gs.ti_int, shape=(_B, 2))
238+
239+
### GJK simplex
240+
struct_simplex_vertex = ti.types.struct(
241+
# Support points on the two objects
242+
obj1=gs.ti_vec3,
243+
obj2=gs.ti_vec3,
244+
# Support point IDs on the two objects
245+
id1=gs.ti_int,
246+
id2=gs.ti_int,
247+
# Vertex on Minkowski difference
248+
mink=gs.ti_vec3,
249+
)
250+
struct_simplex = ti.types.struct(
251+
# Number of vertices in the simplex
252+
nverts=gs.ti_int,
253+
# Distance from the origin to the simplex
254+
dist=gs.ti_float,
255+
)
256+
struct_simplex_buffer = ti.types.struct(
257+
# Normals of the simplex faces
258+
normal=gs.ti_vec3,
259+
# Signed distances of the simplex faces from the origin
260+
sdist=gs.ti_float,
261+
)
262+
self.simplex_vertex = struct_simplex_vertex.field(shape=(_B, 4))
263+
self.simplex_buffer = struct_simplex_buffer.field(shape=(_B, 4))
264+
self.simplex = struct_simplex.field(shape=(_B,))
265+
266+
# Only when we enable MuJoCo compatibility, we use the simplex vertex and buffer for intersection checks.
267+
if static_rigid_sim_config.enable_mujoco_compatibility:
268+
self.simplex_vertex_intersect = struct_simplex_vertex.field(shape=(_B, 4))
269+
self.simplex_buffer_intersect = struct_simplex_buffer.field(shape=(_B, 4))
270+
self.nsimplex = ti.field(dtype=gs.ti_int, shape=(_B,))
271+
272+
# In safe GJK, if the initial simplex is degenerate and the geometries are discrete, we go through vertices
273+
# on the Minkowski difference to find a vertex that would make a valid simplex. To prevent iterating through
274+
# the same vertices again during initial simplex construction, we keep the vertex ID of the last vertex that
275+
# we searched, so that we can start searching from the next vertex.
276+
self.last_searched_simplex_vertex_id = ti.field(dtype=gs.ti_int, shape=(_B,))
277+
278+
### EPA polytope
279+
struct_polytope_vertex = struct_simplex_vertex
280+
struct_polytope_face = ti.types.struct(
281+
# Indices of the vertices forming the face on the polytope
282+
verts_idx=gs.ti_ivec3,
283+
# Indices of adjacent faces, one for each edge: [v1,v2], [v2,v3], [v3,v1]
284+
adj_idx=gs.ti_ivec3,
285+
# Projection of the origin onto the face, can be used as face normal
286+
normal=gs.ti_vec3,
287+
# Square of 2-norm of the normal vector, negative means deleted face
288+
dist2=gs.ti_float,
289+
# Index of the face in the polytope map, -1 for not in the map, -2 for deleted
290+
map_idx=gs.ti_int,
291+
)
292+
# Horizon is used for representing the faces to delete when the polytope is expanded by inserting a new vertex.
293+
struct_polytope_horizon_data = ti.types.struct(
294+
# Indices of faces on horizon
295+
face_idx=gs.ti_int,
296+
# Corresponding edge of each face on the horizon
297+
edge_idx=gs.ti_int,
298+
)
299+
struct_polytope = ti.types.struct(
300+
# Number of vertices in the polytope
301+
nverts=gs.ti_int,
302+
# Number of faces in the polytope (it could include deleted faces)
303+
nfaces=gs.ti_int,
304+
# Number of faces in the polytope map (only valid faces on polytope)
305+
nfaces_map=gs.ti_int,
306+
# Number of edges in the horizon
307+
horizon_nedges=gs.ti_int,
308+
# Support point on the Minkowski difference where the horizon is created
309+
horizon_w=gs.ti_vec3,
310+
)
311+
312+
self.polytope = struct_polytope.field(shape=(_B,))
313+
self.polytope_verts = struct_polytope_vertex.field(shape=(_B, 5 + gjk_static_config.epa_max_iterations))
314+
self.polytope_faces = struct_polytope_face.field(shape=(_B, polytope_max_faces))
315+
self.polytope_horizon_data = struct_polytope_horizon_data.field(
316+
shape=(_B, 6 + gjk_static_config.epa_max_iterations)
317+
)
318+
319+
# Face indices that form the polytope. The first [nfaces_map] indices are the faces that form the polytope.
320+
self.polytope_faces_map = ti.Vector.field(n=polytope_max_faces, dtype=gs.ti_int, shape=(_B,))
321+
322+
# Stack to use for visiting faces during the horizon construction. The size is (# max faces * 3),
323+
# because a face has 3 edges.
324+
self.polytope_horizon_stack = struct_polytope_horizon_data.field(shape=(_B, polytope_max_faces * 3))
325+
326+
# Data structures for multi-contact detection based on MuJoCo's implementation.
327+
if gjk_static_config.enable_mujoco_multi_contact:
328+
struct_contact_face = ti.types.struct(
329+
# Vertices from the two colliding faces
330+
vert1=gs.ti_vec3,
331+
vert2=gs.ti_vec3,
332+
endverts=gs.ti_vec3,
333+
# Normals of the two colliding faces
334+
normal1=gs.ti_vec3,
335+
normal2=gs.ti_vec3,
336+
# Face ID of the two colliding faces
337+
id1=gs.ti_int,
338+
id2=gs.ti_int,
339+
)
340+
# Struct for storing temp. contact normals
341+
struct_contact_normal = ti.types.struct(
342+
endverts=gs.ti_vec3,
343+
# Normal vector of the contact point
344+
normal=gs.ti_vec3,
345+
# Face ID
346+
id=gs.ti_int,
347+
)
348+
struct_contact_halfspace = ti.types.struct(
349+
# Halfspace normal
350+
normal=gs.ti_vec3,
351+
# Halfspace distance from the origin
352+
dist=gs.ti_float,
353+
)
354+
self.contact_faces = struct_contact_face.field(shape=(_B, max_contact_polygon_verts))
355+
self.contact_normals = struct_contact_normal.field(shape=(_B, max_contact_polygon_verts))
356+
self.contact_halfspaces = struct_contact_halfspace.field(shape=(_B, max_contact_polygon_verts))
357+
self.contact_clipped_polygons = gs.ti_vec3.field(shape=(_B, 2, max_contact_polygon_verts))
358+
359+
# Whether or not the MuJoCo's contact manifold detection algorithm was used for the current pair.
360+
self.multi_contact_flag = ti.field(dtype=gs.ti_int, shape=(_B,))
361+
362+
### Final results
363+
# Witness information
364+
struct_witness = ti.types.struct(
365+
# Witness points on the two objects
366+
point_obj1=gs.ti_vec3,
367+
point_obj2=gs.ti_vec3,
368+
)
369+
self.witness = struct_witness.field(shape=(_B, max_contacts_per_pair))
370+
self.n_witness = ti.field(dtype=gs.ti_int, shape=(_B,))
371+
372+
# Contact information, the namings are the same as those from the calling function. Even if they could be
373+
# redundant, we keep them for easier use from the calling function.
374+
self.n_contacts = ti.field(dtype=gs.ti_int, shape=(_B,))
375+
self.contact_pos = gs.ti_vec3.field(shape=(_B, max_contacts_per_pair))
376+
self.normal = gs.ti_vec3.field(shape=(_B, max_contacts_per_pair))
377+
self.is_col = ti.field(dtype=gs.ti_int, shape=(_B,))
378+
self.penetration = ti.field(dtype=gs.ti_float, shape=(_B,))
379+
380+
# Distance between the two objects.
381+
# If the objects are separated, the distance is positive.
382+
# If the objects are intersecting, the distance is negative (depth).
383+
self.distance = ti.field(dtype=gs.ti_float, shape=(_B,))
384+
385+
227386
# =========================================== SupportField ===========================================
228387
@ti.data_oriented
229388
class SupportFieldInfo:

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,16 @@ def detection(self) -> None:
367367
self._solver.geoms_info,
368368
self._solver.geoms_init_AABB,
369369
self._solver.verts_info,
370+
self._solver.faces_info,
370371
self._solver._rigid_global_info,
371372
self._solver._static_rigid_sim_config,
372373
self._collider_state,
373374
self._collider_info,
374375
self._collider_static_config,
375376
self._mpr._mpr_state,
376377
self._mpr._mpr_static_config,
378+
self._gjk._gjk_state if self._gjk is not None else None,
379+
self._gjk._gjk_static_config if self._gjk is not None else None,
377380
self._support_field._support_field_info,
378381
self._support_field._support_field_static_config,
379382
self._mpr,
@@ -1171,13 +1174,16 @@ def _func_narrow_phase_convex_vs_convex(
11711174
geoms_info: array_class.GeomsInfo,
11721175
geoms_init_AABB: array_class.GeomsInitAABB,
11731176
verts_info: array_class.VertsInfo,
1177+
faces_info: array_class.FacesInfo,
11741178
rigid_global_info: ti.template(),
11751179
static_rigid_sim_config: ti.template(),
11761180
collider_state: ti.template(),
11771181
collider_info: ti.template(),
11781182
collider_static_config: ti.template(),
11791183
mpr_state: ti.template(),
11801184
mpr_static_config: ti.template(),
1185+
gjk_state: ti.template(),
1186+
gjk_static_config: ti.template(),
11811187
support_field_info: ti.template(),
11821188
support_field_static_config: ti.template(),
11831189
# FIXME: We need mpr, gjk, sdf, and support_field for now to call their class functions. After migration is
@@ -1224,13 +1230,16 @@ def _func_narrow_phase_convex_vs_convex(
12241230
geoms_info,
12251231
geoms_init_AABB,
12261232
verts_info,
1233+
faces_info,
12271234
rigid_global_info,
12281235
static_rigid_sim_config,
12291236
collider_state,
12301237
collider_info,
12311238
collider_static_config,
12321239
mpr_state,
12331240
mpr_static_config,
1241+
gjk_state,
1242+
gjk_static_config,
12341243
support_field_info,
12351244
support_field_static_config,
12361245
mpr,
@@ -1252,13 +1261,16 @@ def _func_narrow_phase_convex_vs_convex(
12521261
geoms_info,
12531262
geoms_init_AABB,
12541263
verts_info,
1264+
faces_info,
12551265
rigid_global_info,
12561266
static_rigid_sim_config,
12571267
collider_state,
12581268
collider_info,
12591269
collider_static_config,
12601270
mpr_state,
12611271
mpr_static_config,
1272+
gjk_state,
1273+
gjk_static_config,
12621274
support_field_info,
12631275
support_field_static_config,
12641276
mpr,
@@ -1768,13 +1780,16 @@ def _func_convex_convex_contact(
17681780
geoms_info: array_class.GeomsInfo,
17691781
geoms_init_AABB: array_class.GeomsInitAABB,
17701782
verts_info: array_class.VertsInfo,
1783+
faces_info: array_class.FacesInfo,
17711784
rigid_global_info: ti.template(),
17721785
static_rigid_sim_config: ti.template(),
17731786
collider_state: ti.template(),
17741787
collider_info: ti.template(),
17751788
collider_static_config: ti.template(),
17761789
mpr_state: ti.template(),
17771790
mpr_static_config: ti.template(),
1791+
gjk_state: ti.template(),
1792+
gjk_static_config: ti.template(),
17781793
support_field_info: ti.template(),
17791794
support_field_static_config: ti.template(),
17801795
mpr: ti.template(),
@@ -1933,21 +1948,37 @@ def _func_convex_convex_contact(
19331948
elif ti.static(
19341949
collider_static_config.ccd_algorithm in (CCD_ALGORITHM_CODE.GJK, CCD_ALGORITHM_CODE.MJ_GJK)
19351950
):
1936-
gjk.func_gjk_contact(i_ga, i_gb, i_b)
1951+
gjk.func_gjk_contact(
1952+
geoms_state,
1953+
geoms_info,
1954+
verts_info,
1955+
faces_info,
1956+
static_rigid_sim_config,
1957+
collider_state,
1958+
collider_static_config,
1959+
gjk_state,
1960+
gjk_static_config,
1961+
support_field_info,
1962+
support_field_static_config,
1963+
support_field,
1964+
i_ga,
1965+
i_gb,
1966+
i_b,
1967+
)
19371968

1938-
is_col = gjk.is_col[i_b] == 1
1939-
penetration = gjk.penetration[i_b]
1940-
n_contacts = gjk.n_contacts[i_b]
1969+
is_col = gjk_state.is_col[i_b] == 1
1970+
penetration = gjk_state.penetration[i_b]
1971+
n_contacts = gjk_state.n_contacts[i_b]
19411972

19421973
if is_col:
1943-
if gjk.multi_contact_flag[i_b]:
1974+
if gjk_state.multi_contact_flag[i_b]:
19441975
# Used MuJoCo's multi-contact algorithm to find multiple contact points. Therefore,
19451976
# add the discovered contact points and stop multi-contact search.
19461977
for i_c in range(n_contacts):
19471978
# Ignore contact points if the number of contacts exceeds the limit.
19481979
if i_c < ti.static(collider_static_config.n_contacts_per_pair):
1949-
contact_pos = gjk.contact_pos[i_b, i_c]
1950-
normal = gjk.normal[i_b, i_c]
1980+
contact_pos = gjk_state.contact_pos[i_b, i_c]
1981+
normal = gjk_state.normal[i_b, i_c]
19511982
self_unused._func_add_contact(
19521983
geoms_state,
19531984
geoms_info,
@@ -1963,8 +1994,8 @@ def _func_convex_convex_contact(
19631994

19641995
break
19651996
else:
1966-
contact_pos = gjk.contact_pos[i_b, 0]
1967-
normal = gjk.normal[i_b, 0]
1997+
contact_pos = gjk_state.contact_pos[i_b, 0]
1998+
normal = gjk_state.normal[i_b, 0]
19681999

19692000
if ti.static(collider_static_config.ccd_algorithm == CCD_ALGORITHM_CODE.MPR):
19702001
if try_sdf:

0 commit comments

Comments
 (0)