Skip to content

Commit 8ff9b4e

Browse files
committed
removed initialization logic from collider state
1 parent 8f4eb84 commit 8ff9b4e

File tree

5 files changed

+180
-148
lines changed

5 files changed

+180
-148
lines changed

genesis/engine/solvers/rigid/array_class.py

Lines changed: 26 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,34 @@ class ColliderState:
3434
Class to store the mutable collider data, all of which type is [ti.fields].
3535
"""
3636

37-
def __init__(self, solver, collider_info):
37+
def __init__(self, solver, n_possible_pairs, n_vert_neighbors, collider_info):
38+
"""
39+
Parameters:
40+
----------
41+
n_possible_pairs: int
42+
Maximum number of possible collision pairs based on geom configurations. For instance, when adjacent
43+
collision is disabled, adjacent geoms are not considered in counting possible pairs.
44+
n_vert_neighbors: int
45+
Size of the vertex neighbors array.
46+
"""
3847
_B = solver._B
3948
f_batch = solver._batch_shape
4049
n_geoms = solver.n_geoms_
41-
max_collision_pairs = solver._max_collision_pairs
50+
n_verts = solver.n_verts_
51+
max_collision_pairs = min(solver._max_collision_pairs, n_possible_pairs)
52+
max_contact_pairs = max_collision_pairs * collider_info.n_contacts_per_pair
4253
use_hibernation = solver._static_rigid_sim_config.use_hibernation
54+
box_box_detection = solver._static_rigid_sim_config.box_box_detection
4355

4456
############## vertex connectivity ##############
45-
self._init_verts_connectivity(solver)
57+
self.vert_neighbors = ti.field(dtype=gs.ti_int, shape=max(1, n_vert_neighbors))
58+
self.vert_neighbor_start = ti.field(dtype=gs.ti_int, shape=n_verts)
59+
self.vert_n_neighbors = ti.field(dtype=gs.ti_int, shape=n_verts)
4660

4761
############## broad phase SAP ##############
4862
# This buffer stores the AABBs along the search axis of all geoms
4963
struct_sort_buffer = ti.types.struct(value=gs.ti_float, i_g=gs.ti_int, is_max=gs.ti_int)
50-
self.sort_buffer = struct_sort_buffer.field(
51-
shape=f_batch(2 * n_geoms),
52-
layout=ti.Layout.SOA,
53-
)
64+
self.sort_buffer = struct_sort_buffer.field(shape=f_batch(2 * n_geoms), layout=ti.Layout.SOA)
5465

5566
# This buffer stores indexes of active geoms during SAP search
5667
if use_hibernation:
@@ -60,7 +71,6 @@ def __init__(self, solver, collider_info):
6071

6172
# Stores the validity of the collision pairs
6273
self.collision_pair_validity = ti.field(dtype=gs.ti_int, shape=(n_geoms, n_geoms))
63-
n_possible_pairs = self._init_collision_pair_validity(solver)
6474

6575
# Whether or not this is the first time to run the broad phase for each batch
6676
self.first_time = ti.field(gs.ti_int, shape=_B)
@@ -70,15 +80,9 @@ def __init__(self, solver, collider_info):
7080
self._max_collision_pairs = ti.field(dtype=gs.ti_int, shape=())
7181
self._max_contact_pairs = ti.field(dtype=gs.ti_int, shape=())
7282

73-
self._max_possible_pairs[None] = n_possible_pairs
74-
self._max_collision_pairs[None] = min(n_possible_pairs, max_collision_pairs)
75-
self._max_contact_pairs[None] = self._max_collision_pairs[None] * collider_info.n_contacts_per_pair
76-
7783
# Final results of the broad phase
7884
self.n_broad_pairs = ti.field(dtype=gs.ti_int, shape=_B)
79-
self.broad_collision_pairs = ti.Vector.field(
80-
2, dtype=gs.ti_int, shape=f_batch(max(1, self._max_collision_pairs[None]))
81-
)
85+
self.broad_collision_pairs = ti.Vector.field(2, dtype=gs.ti_int, shape=f_batch(max(1, max_collision_pairs)))
8286

8387
############## narrow phase ##############
8488
struct_contact_data = ti.types.struct(
@@ -94,32 +98,23 @@ def __init__(self, solver, collider_info):
9498
link_b=gs.ti_int,
9599
)
96100
self.contact_data = struct_contact_data.field(
97-
shape=f_batch(max(1, self._max_contact_pairs[None])),
101+
shape=f_batch(max(1, max_contact_pairs)),
98102
layout=ti.Layout.SOA,
99103
)
100104
# total number of contacts, including hibernated contacts
101105
self.n_contacts = ti.field(gs.ti_int, shape=_B)
102106
self.n_contacts_hibernated = ti.field(gs.ti_int, shape=_B)
103-
self._contacts_info_cache = {}
104107

105108
# contact caching for warmstart collision detection
106109
struct_contact_cache = ti.types.struct(
107110
# i_va_ws=gs.ti_int,
108111
# penetration=gs.ti_float,
109112
normal=gs.ti_vec3,
110113
)
111-
self.contact_cache = struct_contact_cache.field(
112-
shape=f_batch((n_geoms, n_geoms)),
113-
layout=ti.Layout.SOA,
114-
)
115-
116-
# for faster compilation
117-
if collider_info.has_terrain:
118-
self.xyz_max_min = ti.field(dtype=gs.ti_float, shape=f_batch(6))
119-
self.prism = ti.field(dtype=gs.ti_vec3, shape=f_batch(6))
114+
self.contact_cache = struct_contact_cache.field(shape=f_batch((n_geoms, n_geoms)), layout=ti.Layout.SOA)
120115

121116
########## Box-box contact detection ##########
122-
if solver._box_box_detection:
117+
if box_box_detection:
123118
# With the existing Box-Box collision detection algorithm, it is not clear where the contact points are
124119
# located depending of the pose and size of each box. In practice, up to 11 contact points have been
125120
# observed. The theoretical worst case scenario would be 2 cubes roughly the same size and same center,
@@ -138,108 +133,11 @@ def __init__(self, solver, collider_info):
138133
links_idx = solver.geoms_info.link_idx.to_numpy()[solver.geoms_info.type.to_numpy() == gs.GEOM_TYPE.TERRAIN]
139134
entity = solver._entities[solver.links_info.entity_idx.to_numpy()[links_idx[0]]]
140135

141-
scale = entity.terrain_scale.astype(gs.np_float)
142-
rc = np.array(entity.terrain_hf.shape, dtype=gs.np_int)
143-
hf = entity.terrain_hf.astype(gs.np_float) * scale[1]
144-
xyz_maxmin = np.array(
145-
[rc[0] * scale[0], rc[1] * scale[0], hf.max(), 0, 0, hf.min() - 1.0],
146-
dtype=gs.np_float,
147-
)
148-
149-
self.terrain_hf = ti.field(dtype=gs.ti_float, shape=hf.shape)
136+
self.terrain_hf = ti.field(dtype=gs.ti_float, shape=entity.terrain_hf.shape)
150137
self.terrain_rc = ti.field(dtype=gs.ti_int, shape=2)
151138
self.terrain_scale = ti.field(dtype=gs.ti_float, shape=2)
152139
self.terrain_xyz_maxmin = ti.field(dtype=gs.ti_float, shape=6)
153140

154-
self.terrain_hf.from_numpy(hf)
155-
self.terrain_rc.from_numpy(rc)
156-
self.terrain_scale.from_numpy(scale)
157-
self.terrain_xyz_maxmin.from_numpy(xyz_maxmin)
158-
159-
def _init_verts_connectivity(self, solver) -> None:
160-
"""
161-
Initialize the vertex connectivity fields.
162-
"""
163-
vert_neighbors = []
164-
vert_neighbor_start = []
165-
vert_n_neighbors = []
166-
offset = 0
167-
for geom in solver.geoms:
168-
vert_neighbors.append(geom.vert_neighbors + geom.vert_start)
169-
vert_neighbor_start.append(geom.vert_neighbor_start + offset)
170-
vert_n_neighbors.append(geom.vert_n_neighbors)
171-
offset += len(geom.vert_neighbors)
172-
173-
if solver.n_verts > 0:
174-
vert_neighbors = np.concatenate(vert_neighbors, dtype=gs.np_int)
175-
vert_neighbor_start = np.concatenate(vert_neighbor_start, dtype=gs.np_int)
176-
vert_n_neighbors = np.concatenate(vert_n_neighbors, dtype=gs.np_int)
177-
178-
self.vert_neighbors = ti.field(dtype=gs.ti_int, shape=max(1, len(vert_neighbors)))
179-
self.vert_neighbor_start = ti.field(dtype=gs.ti_int, shape=solver.n_verts_)
180-
self.vert_n_neighbors = ti.field(dtype=gs.ti_int, shape=solver.n_verts_)
181-
182-
if solver.n_verts > 0:
183-
self.vert_neighbors.from_numpy(vert_neighbors)
184-
self.vert_neighbor_start.from_numpy(vert_neighbor_start)
185-
self.vert_n_neighbors.from_numpy(vert_n_neighbors)
186-
187-
def _init_collision_pair_validity(self, solver):
188-
"""
189-
Initialize the collision pair validity matrix.
190-
191-
For each pair of geoms, determine if they can collide based on their properties and the solver configuration.
192-
"""
193-
n_geoms = solver.n_geoms_
194-
enable_self_collision = solver._static_rigid_sim_config.enable_self_collision
195-
enable_adjacent_collision = solver._static_rigid_sim_config.enable_adjacent_collision
196-
batch_links_info = solver._static_rigid_sim_config.batch_links_info
197-
198-
geoms_link_idx = solver.geoms_info.link_idx.to_numpy()
199-
geoms_contype = solver.geoms_info.contype.to_numpy()
200-
geoms_conaffinity = solver.geoms_info.conaffinity.to_numpy()
201-
links_entity_idx = solver.links_info.entity_idx.to_numpy()
202-
links_root_idx = solver.links_info.root_idx.to_numpy()
203-
links_parent_idx = solver.links_info.parent_idx.to_numpy()
204-
links_is_fixed = solver.links_info.is_fixed.to_numpy()
205-
if batch_links_info:
206-
links_entity_idx = links_entity_idx[:, 0]
207-
links_root_idx = links_root_idx[:, 0]
208-
links_parent_idx = links_parent_idx[:, 0]
209-
links_is_fixed = links_is_fixed[:, 0]
210-
211-
n_possible_pairs = 0
212-
for i_ga in range(n_geoms):
213-
for i_gb in range(i_ga + 1, n_geoms):
214-
i_la = geoms_link_idx[i_ga]
215-
i_lb = geoms_link_idx[i_gb]
216-
217-
# geoms in the same link
218-
if i_la == i_lb:
219-
continue
220-
221-
# self collision
222-
if links_root_idx[i_la] == links_root_idx[i_lb]:
223-
if not enable_self_collision:
224-
continue
225-
226-
# adjacent links
227-
if not enable_adjacent_collision and (
228-
links_parent_idx[i_la] == i_lb or links_parent_idx[i_lb] == i_la
229-
):
230-
continue
231-
232-
# contype and conaffinity
233-
if links_entity_idx[i_la] == links_entity_idx[i_lb] and not (
234-
(geoms_contype[i_ga] & geoms_conaffinity[i_gb]) or (geoms_contype[i_gb] & geoms_conaffinity[i_ga])
235-
):
236-
continue
237-
238-
# pair of fixed links wrt the world
239-
if links_is_fixed[i_la] and links_is_fixed[i_lb]:
240-
continue
241-
242-
self.collision_pair_validity[i_ga, i_gb] = 1
243-
n_possible_pairs += 1
244-
245-
return n_possible_pairs
141+
# for faster compilation
142+
self.xyz_max_min = ti.field(dtype=gs.ti_float, shape=f_batch(6))
143+
self.prism = ti.field(dtype=gs.ti_vec3, shape=f_batch(6))

0 commit comments

Comments
 (0)