@@ -40,7 +40,7 @@ def __init__(self, solver, n_dofs: int, n_entities: int, n_geoms: int, _B: int,
4040 # self.links_T = ti.Matrix.field(n=4, m=4, dtype=gs.ti_float, shape=solver.n_links)
4141
4242
43- # =========================================== Collider ===========================================
43+ # =========================================== Constraint ===========================================
4444
4545
4646@ti .data_oriented
@@ -54,13 +54,16 @@ def __init__(self, solver):
5454 self .n_constraints = ti .field (dtype = gs .ti_int , shape = f_batch ())
5555
5656
57+ # =========================================== Collider ===========================================
58+
59+
5760@ti .data_oriented
5861class ColliderState :
5962 """
60- Class to store the mutable collider data, all of which type is [ti.fields].
63+ Class to store the MUTABLE collider data, all of which type is [ti.fields] (later we will support NDArrays) .
6164 """
6265
63- def __init__ (self , solver , n_possible_pairs , n_vert_neighbors , collider_info ):
66+ def __init__ (self , solver , n_possible_pairs , collider_static_config ):
6467 """
6568 Parameters:
6669 ----------
@@ -73,17 +76,11 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
7376 _B = solver ._B
7477 f_batch = solver ._batch_shape
7578 n_geoms = solver .n_geoms_
76- n_verts = solver .n_verts_
7779 max_collision_pairs = min (solver ._max_collision_pairs , n_possible_pairs )
78- max_contact_pairs = max_collision_pairs * collider_info .n_contacts_per_pair
80+ max_contact_pairs = max_collision_pairs * collider_static_config .n_contacts_per_pair
7981 use_hibernation = solver ._static_rigid_sim_config .use_hibernation
8082 box_box_detection = solver ._static_rigid_sim_config .box_box_detection
8183
82- ############## vertex connectivity ##############
83- self .vert_neighbors = ti .field (dtype = gs .ti_int , shape = max (1 , n_vert_neighbors ))
84- self .vert_neighbor_start = ti .field (dtype = gs .ti_int , shape = n_verts )
85- self .vert_n_neighbors = ti .field (dtype = gs .ti_int , shape = n_verts )
86-
8784 ############## broad phase SAP ##############
8885 # This buffer stores the AABBs along the search axis of all geoms
8986 struct_sort_buffer = ti .types .struct (value = gs .ti_float , i_g = gs .ti_int , is_max = gs .ti_int )
@@ -95,9 +92,6 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
9592 self .active_buffer_hib = ti .field (dtype = gs .ti_int , shape = f_batch (n_geoms ))
9693 self .active_buffer = ti .field (dtype = gs .ti_int , shape = f_batch (n_geoms ))
9794
98- # Stores the validity of the collision pairs
99- self .collision_pair_validity = ti .field (dtype = gs .ti_int , shape = (n_geoms , n_geoms ))
100-
10195 # Whether or not this is the first time to run the broad phase for each batch
10296 self .first_time = ti .field (gs .ti_int , shape = _B )
10397
@@ -145,8 +139,8 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
145139 # located depending of the pose and size of each box. In practice, up to 11 contact points have been
146140 # observed. The theoretical worst case scenario would be 2 cubes roughly the same size and same center,
147141 # with transform RPY = (45, 45, 45), resulting in 3 contact points per faces for a total of 16 points.
148- self .box_depth = ti .field (dtype = gs .ti_float , shape = f_batch (collider_info .box_MAXCONPAIR ))
149- self .box_points = ti .field (gs .ti_vec3 , shape = f_batch (collider_info .box_MAXCONPAIR ))
142+ self .box_depth = ti .field (dtype = gs .ti_float , shape = f_batch (collider_static_config .box_MAXCONPAIR ))
143+ self .box_points = ti .field (gs .ti_vec3 , shape = f_batch (collider_static_config .box_MAXCONPAIR ))
150144 self .box_pts = ti .field (gs .ti_vec3 , shape = f_batch (6 ))
151145 self .box_lines = ti .field (gs .ti_vec6 , shape = f_batch (4 ))
152146 self .box_linesu = ti .field (gs .ti_vec6 , shape = f_batch (4 ))
@@ -155,7 +149,44 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
155149 self .box_pu = ti .field (gs .ti_vec3 , shape = f_batch (4 ))
156150
157151 ########## Terrain contact detection ##########
158- if collider_info .has_terrain :
152+ if collider_static_config .has_terrain :
153+ # for faster compilation
154+ self .xyz_max_min = ti .field (dtype = gs .ti_float , shape = f_batch (6 ))
155+ self .prism = ti .field (dtype = gs .ti_vec3 , shape = f_batch (6 ))
156+
157+
158+ @ti .data_oriented
159+ class ColliderInfo :
160+ """
161+ Class to store the IMMUTABLE collider data, all of which type is [ti.fields] (later we will support NDArrays).
162+ """
163+
164+ def __init__ (self , solver , n_vert_neighbors , collider_static_config ):
165+ """
166+ Parameters:
167+ ----------
168+ n_vert_neighbors: int
169+ Size of the vertex neighbors array.
170+ """
171+ n_geoms = solver .n_geoms_
172+ n_verts = solver .n_verts_
173+
174+ ############## vertex connectivity ##############
175+ self .vert_neighbors = ti .field (dtype = gs .ti_int , shape = max (1 , n_vert_neighbors ))
176+ self .vert_neighbor_start = ti .field (dtype = gs .ti_int , shape = n_verts )
177+ self .vert_n_neighbors = ti .field (dtype = gs .ti_int , shape = n_verts )
178+
179+ ############## broad phase SAP ##############
180+ # Stores the validity of the collision pairs
181+ self .collision_pair_validity = ti .field (dtype = gs .ti_int , shape = (n_geoms , n_geoms ))
182+
183+ # Number of possible pairs of collision, store them in a field to avoid recompilation
184+ self ._max_possible_pairs = ti .field (dtype = gs .ti_int , shape = ())
185+ self ._max_collision_pairs = ti .field (dtype = gs .ti_int , shape = ())
186+ self ._max_contact_pairs = ti .field (dtype = gs .ti_int , shape = ())
187+
188+ ########## Terrain contact detection ##########
189+ if collider_static_config .has_terrain :
159190 links_idx = solver .geoms_info .link_idx .to_numpy ()[solver .geoms_info .type .to_numpy () == gs .GEOM_TYPE .TERRAIN ]
160191 entity = solver ._entities [solver .links_info .entity_idx .to_numpy ()[links_idx [0 ]]]
161192
@@ -164,6 +195,31 @@ def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
164195 self .terrain_scale = ti .field (dtype = gs .ti_float , shape = 2 )
165196 self .terrain_xyz_maxmin = ti .field (dtype = gs .ti_float , shape = 6 )
166197
167- # for faster compilation
168- self .xyz_max_min = ti .field (dtype = gs .ti_float , shape = f_batch (6 ))
169- self .prism = ti .field (dtype = gs .ti_vec3 , shape = f_batch (6 ))
198+
199+ # =========================================== MPR ===========================================
200+ @ti .data_oriented
201+ class MPRState :
202+ def __init__ (self , f_batch ):
203+ struct_support = ti .types .struct (
204+ v1 = gs .ti_vec3 ,
205+ v2 = gs .ti_vec3 ,
206+ v = gs .ti_vec3 ,
207+ )
208+ self .simplex_support = struct_support .field (
209+ shape = f_batch (4 ),
210+ layout = ti .Layout .SOA ,
211+ )
212+ self .simplex_size = ti .field (gs .ti_int , shape = f_batch ())
213+
214+
215+ # =========================================== SupportField ===========================================
216+ @ti .data_oriented
217+ class SupportFieldInfo :
218+ """
219+ Class to store the IMMUTABLE support field data, all of which type is [ti.fields] (later we will support NDArrays).
220+ """
221+
222+ def __init__ (self , n_geoms , n_support_cells ):
223+ self .support_cell_start = ti .field (dtype = gs .ti_int , shape = n_geoms )
224+ self .support_v = ti .Vector .field (3 , dtype = gs .ti_float , shape = max (1 , n_support_cells ))
225+ self .support_vid = ti .field (dtype = gs .ti_int , shape = max (1 , n_support_cells ))
0 commit comments