Skip to content

Commit 67b1420

Browse files
authored
[Migration] MPR migration (#1412)
1 parent 76da6db commit 67b1420

File tree

12 files changed

+994
-396
lines changed

12 files changed

+994
-396
lines changed

genesis/engine/solvers/rigid/array_class.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, solver, n_dofs: int, n_entities: int, n_geoms: int, _B: int,
4040
# self.links_T = ti.Matrix.field(n=4, m=4, dtype=gs.ti_float, shape=solver.n_links)
4141

4242

43-
# =========================================== Collider ===========================================
43+
# =========================================== Constraint ===========================================
4444

4545

4646
@ti.data_oriented
@@ -54,13 +54,16 @@ def __init__(self, solver):
5454
self.n_constraints = ti.field(dtype=gs.ti_int, shape=f_batch())
5555

5656

57+
# =========================================== Collider ===========================================
58+
59+
5760
@ti.data_oriented
5861
class ColliderState:
5962
"""
60-
Class to store the mutable collider data, all of which type is [ti.fields].
63+
Class to store the MUTABLE collider data, all of which type is [ti.fields] (later we will support NDArrays).
6164
"""
6265

63-
def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
66+
def __init__(self, solver, n_possible_pairs, collider_static_config):
6467
"""
6568
Parameters:
6669
----------
@@ -73,17 +76,11 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
7376
_B = solver._B
7477
f_batch = solver._batch_shape
7578
n_geoms = solver.n_geoms_
76-
n_verts = solver.n_verts_
7779
max_collision_pairs = min(solver._max_collision_pairs, n_possible_pairs)
78-
max_contact_pairs = max_collision_pairs * collider_info.n_contacts_per_pair
80+
max_contact_pairs = max_collision_pairs * collider_static_config.n_contacts_per_pair
7981
use_hibernation = solver._static_rigid_sim_config.use_hibernation
8082
box_box_detection = solver._static_rigid_sim_config.box_box_detection
8183

82-
############## vertex connectivity ##############
83-
self.vert_neighbors = ti.field(dtype=gs.ti_int, shape=max(1, n_vert_neighbors))
84-
self.vert_neighbor_start = ti.field(dtype=gs.ti_int, shape=n_verts)
85-
self.vert_n_neighbors = ti.field(dtype=gs.ti_int, shape=n_verts)
86-
8784
############## broad phase SAP ##############
8885
# This buffer stores the AABBs along the search axis of all geoms
8986
struct_sort_buffer = ti.types.struct(value=gs.ti_float, i_g=gs.ti_int, is_max=gs.ti_int)
@@ -95,9 +92,6 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
9592
self.active_buffer_hib = ti.field(dtype=gs.ti_int, shape=f_batch(n_geoms))
9693
self.active_buffer = ti.field(dtype=gs.ti_int, shape=f_batch(n_geoms))
9794

98-
# Stores the validity of the collision pairs
99-
self.collision_pair_validity = ti.field(dtype=gs.ti_int, shape=(n_geoms, n_geoms))
100-
10195
# Whether or not this is the first time to run the broad phase for each batch
10296
self.first_time = ti.field(gs.ti_int, shape=_B)
10397

@@ -145,8 +139,8 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
145139
# located depending of the pose and size of each box. In practice, up to 11 contact points have been
146140
# observed. The theoretical worst case scenario would be 2 cubes roughly the same size and same center,
147141
# with transform RPY = (45, 45, 45), resulting in 3 contact points per faces for a total of 16 points.
148-
self.box_depth = ti.field(dtype=gs.ti_float, shape=f_batch(collider_info.box_MAXCONPAIR))
149-
self.box_points = ti.field(gs.ti_vec3, shape=f_batch(collider_info.box_MAXCONPAIR))
142+
self.box_depth = ti.field(dtype=gs.ti_float, shape=f_batch(collider_static_config.box_MAXCONPAIR))
143+
self.box_points = ti.field(gs.ti_vec3, shape=f_batch(collider_static_config.box_MAXCONPAIR))
150144
self.box_pts = ti.field(gs.ti_vec3, shape=f_batch(6))
151145
self.box_lines = ti.field(gs.ti_vec6, shape=f_batch(4))
152146
self.box_linesu = ti.field(gs.ti_vec6, shape=f_batch(4))
@@ -155,7 +149,44 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
155149
self.box_pu = ti.field(gs.ti_vec3, shape=f_batch(4))
156150

157151
########## Terrain contact detection ##########
158-
if collider_info.has_terrain:
152+
if collider_static_config.has_terrain:
153+
# for faster compilation
154+
self.xyz_max_min = ti.field(dtype=gs.ti_float, shape=f_batch(6))
155+
self.prism = ti.field(dtype=gs.ti_vec3, shape=f_batch(6))
156+
157+
158+
@ti.data_oriented
159+
class ColliderInfo:
160+
"""
161+
Class to store the IMMUTABLE collider data, all of which type is [ti.fields] (later we will support NDArrays).
162+
"""
163+
164+
def __init__(self, solver, n_vert_neighbors, collider_static_config):
165+
"""
166+
Parameters:
167+
----------
168+
n_vert_neighbors: int
169+
Size of the vertex neighbors array.
170+
"""
171+
n_geoms = solver.n_geoms_
172+
n_verts = solver.n_verts_
173+
174+
############## vertex connectivity ##############
175+
self.vert_neighbors = ti.field(dtype=gs.ti_int, shape=max(1, n_vert_neighbors))
176+
self.vert_neighbor_start = ti.field(dtype=gs.ti_int, shape=n_verts)
177+
self.vert_n_neighbors = ti.field(dtype=gs.ti_int, shape=n_verts)
178+
179+
############## broad phase SAP ##############
180+
# Stores the validity of the collision pairs
181+
self.collision_pair_validity = ti.field(dtype=gs.ti_int, shape=(n_geoms, n_geoms))
182+
183+
# Number of possible pairs of collision, store them in a field to avoid recompilation
184+
self._max_possible_pairs = ti.field(dtype=gs.ti_int, shape=())
185+
self._max_collision_pairs = ti.field(dtype=gs.ti_int, shape=())
186+
self._max_contact_pairs = ti.field(dtype=gs.ti_int, shape=())
187+
188+
########## Terrain contact detection ##########
189+
if collider_static_config.has_terrain:
159190
links_idx = solver.geoms_info.link_idx.to_numpy()[solver.geoms_info.type.to_numpy() == gs.GEOM_TYPE.TERRAIN]
160191
entity = solver._entities[solver.links_info.entity_idx.to_numpy()[links_idx[0]]]
161192

@@ -164,6 +195,31 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
164195
self.terrain_scale = ti.field(dtype=gs.ti_float, shape=2)
165196
self.terrain_xyz_maxmin = ti.field(dtype=gs.ti_float, shape=6)
166197

167-
# for faster compilation
168-
self.xyz_max_min = ti.field(dtype=gs.ti_float, shape=f_batch(6))
169-
self.prism = ti.field(dtype=gs.ti_vec3, shape=f_batch(6))
198+
199+
# =========================================== MPR ===========================================
200+
@ti.data_oriented
201+
class MPRState:
202+
def __init__(self, f_batch):
203+
struct_support = ti.types.struct(
204+
v1=gs.ti_vec3,
205+
v2=gs.ti_vec3,
206+
v=gs.ti_vec3,
207+
)
208+
self.simplex_support = struct_support.field(
209+
shape=f_batch(4),
210+
layout=ti.Layout.SOA,
211+
)
212+
self.simplex_size = ti.field(gs.ti_int, shape=f_batch())
213+
214+
215+
# =========================================== SupportField ===========================================
216+
@ti.data_oriented
217+
class SupportFieldInfo:
218+
"""
219+
Class to store the IMMUTABLE support field data, all of which type is [ti.fields] (later we will support NDArrays).
220+
"""
221+
222+
def __init__(self, n_geoms, n_support_cells):
223+
self.support_cell_start = ti.field(dtype=gs.ti_int, shape=n_geoms)
224+
self.support_v = ti.Vector.field(3, dtype=gs.ti_float, shape=max(1, n_support_cells))
225+
self.support_vid = ti.field(dtype=gs.ti_int, shape=max(1, n_support_cells))

0 commit comments

Comments
 (0)