Skip to content

Commit 7f84304

Browse files
authored
Add support for plane vs convex (#107)
* Plane vs convex collision starts to work * Ran ruff format * Fix argument lists * Add safe guards to avoid contact overflow * Adapt new argument lists in collision_primitive.py * Remove pot only scene * Address MR comments from github * Fix accessing the model struct in a method where it is no longer accessible * Change argument order
1 parent dc2d4d9 commit 7f84304

File tree

4 files changed

+218
-5
lines changed

4 files changed

+218
-5
lines changed

mujoco_warp/_src/collision_convex.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,21 +781,25 @@ def gjk_epa_sparse(
781781
return
782782

783783
geom1 = _geom(
784+
geom_type,
784785
geom_dataid,
785786
geom_size,
786787
mesh_vertadr,
787788
mesh_vertnum,
789+
mesh_vert,
788790
geom_xpos_in,
789791
geom_xmat_in,
790792
worldid,
791793
g1,
792794
)
793795

794796
geom2 = _geom(
797+
geom_type,
795798
geom_dataid,
796799
geom_size,
797800
mesh_vertadr,
798801
mesh_vertnum,
802+
mesh_vert,
799803
geom_xpos_in,
800804
geom_xmat_in,
801805
worldid,

mujoco_warp/_src/collision_driver_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,20 @@ class CollisionTest(parameterized.TestCase):
202202
</worldbody>
203203
</mujoco>
204204
""",
205+
"mesh_plane": """
206+
<mujoco>
207+
<asset>
208+
<mesh name="cube" vertex="1 1 1 1 1 -1 1 -1 1 1 -1 -1 -1 1 1 -1 1 -1 -1 -1 1 -1 -1 -1"/>
209+
</asset>
210+
<worldbody>
211+
<geom size="40 40 40" type="plane"/>
212+
<body pos="0 0 1" euler="45 0 0">
213+
<freejoint/>
214+
<geom type="mesh" mesh="cube"/>
215+
</body>
216+
</worldbody>
217+
</mujoco>
218+
""",
205219
"sphere_box_shallow": """
206220
<mujoco>
207221
<worldbody>

mujoco_warp/_src/collision_primitive.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,18 @@ class Geom:
3636
size: wp.vec3
3737
vertadr: int
3838
vertnum: int
39+
vert: wp.array(dtype=wp.vec3)
3940

4041

4142
@wp.func
4243
def _geom(
4344
# Model:
45+
geom_type: wp.array(dtype=int),
4446
geom_dataid: wp.array(dtype=int),
4547
geom_size: wp.array(dtype=wp.vec3),
4648
mesh_vertadr: wp.array(dtype=int),
4749
mesh_vertnum: wp.array(dtype=int),
50+
mesh_vert: wp.array(dtype=wp.vec3),
4851
# Data in:
4952
geom_xpos_in: wp.array2d(dtype=wp.vec3),
5053
geom_xmat_in: wp.array2d(dtype=wp.mat33),
@@ -67,6 +70,9 @@ def _geom(
6770
geom.vertadr = -1
6871
geom.vertnum = -1
6972

73+
if geom_type[gid] == int(GeomType.MESH.value):
74+
geom.vert = mesh_vert
75+
7076
return geom
7177

7278

@@ -724,6 +730,152 @@ def plane_box(
724730
break
725731

726732

733+
_HUGE_VAL = 1e6
734+
735+
736+
@wp.func
737+
def plane_convex(
738+
# Data in:
739+
nconmax_in: int,
740+
# In:
741+
plane: Geom,
742+
convex: Geom,
743+
worldid: int,
744+
margin: float,
745+
gap: float,
746+
condim: int,
747+
friction: vec5,
748+
solref: wp.vec2f,
749+
solreffriction: wp.vec2f,
750+
solimp: vec5,
751+
geoms: wp.vec2i,
752+
# Data out:
753+
ncon_out: wp.array(dtype=int),
754+
contact_dist_out: wp.array(dtype=float),
755+
contact_pos_out: wp.array(dtype=wp.vec3),
756+
contact_frame_out: wp.array(dtype=wp.mat33),
757+
contact_includemargin_out: wp.array(dtype=float),
758+
contact_friction_out: wp.array(dtype=vec5),
759+
contact_solref_out: wp.array(dtype=wp.vec2),
760+
contact_solreffriction_out: wp.array(dtype=wp.vec2),
761+
contact_solimp_out: wp.array(dtype=vec5),
762+
contact_dim_out: wp.array(dtype=int),
763+
contact_geom_out: wp.array(dtype=wp.vec2i),
764+
contact_worldid_out: wp.array(dtype=int),
765+
):
766+
"""Calculates contacts between a plane and a convex object."""
767+
768+
# get points in the convex frame
769+
plane_pos = wp.transpose(convex.rot) @ (plane.pos - convex.pos)
770+
n = wp.transpose(convex.rot) @ plane.normal
771+
772+
# Find support points
773+
max_support = wp.float32(-_HUGE_VAL)
774+
for i in range(convex.vertnum):
775+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + i], n)
776+
777+
max_support = wp.max(support, max_support)
778+
779+
threshold = wp.max(0.0, max_support - 1e-3)
780+
781+
# Store indices in vec4
782+
indices = wp.vec4i(-1, -1, -1, -1)
783+
784+
# TODO(team): Explore faster methods like tile_min or even fast pass kernels if the upper bound of vertices in all convexes is small enough such that all vertices fit into shared memory
785+
# Find point a (first support point)
786+
a_dist = wp.float32(-_HUGE_VAL)
787+
for i in range(convex.vertnum):
788+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + i], n)
789+
dist = wp.where(support > threshold, 0.0, -_HUGE_VAL)
790+
if dist > a_dist:
791+
indices[0] = i
792+
a_dist = dist
793+
a = convex.vert[convex.vertadr + indices[0]]
794+
795+
# Find point b (furthest from a)
796+
b_dist = wp.float32(-_HUGE_VAL)
797+
for i in range(convex.vertnum):
798+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + i], n)
799+
dist_mask = wp.where(support > threshold, 0.0, -_HUGE_VAL)
800+
dist = wp.length_sq(a - convex.vert[convex.vertadr + i]) + dist_mask
801+
if dist > b_dist:
802+
indices[1] = i
803+
b_dist = dist
804+
b = convex.vert[convex.vertadr + indices[1]]
805+
806+
# Find point c (furthest along axis orthogonal to a-b)
807+
ab = wp.cross(n, a - b)
808+
c_dist = wp.float32(-_HUGE_VAL)
809+
for i in range(convex.vertnum):
810+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + i], n)
811+
dist_mask = wp.where(support > threshold, 0.0, -_HUGE_VAL)
812+
ap = a - convex.vert[convex.vertadr + i]
813+
dist = wp.abs(wp.dot(ap, ab)) + dist_mask
814+
if dist > c_dist:
815+
indices[2] = i
816+
c_dist = dist
817+
c = convex.vert[convex.vertadr + indices[2]]
818+
819+
# Find point d (furthest from other triangle edges)
820+
ac = wp.cross(n, a - c)
821+
bc = wp.cross(n, b - c)
822+
d_dist = wp.float32(-_HUGE_VAL)
823+
for i in range(convex.vertnum):
824+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + i], n)
825+
dist_mask = wp.where(support > threshold, 0.0, -_HUGE_VAL)
826+
ap = a - convex.vert[convex.vertadr + i]
827+
bp = b - convex.vert[convex.vertadr + i]
828+
dist_ap = wp.abs(wp.dot(ap, ac)) + dist_mask
829+
dist_bp = wp.abs(wp.dot(bp, bc)) + dist_mask
830+
if dist_ap + dist_bp > d_dist:
831+
indices[3] = i
832+
d_dist = dist_ap + dist_bp
833+
834+
# Write contacts
835+
frame = make_frame(plane.normal)
836+
for i in range(3, -1, -1):
837+
idx = indices[i]
838+
count = int(0)
839+
for j in range(i + 1):
840+
if indices[j] == idx:
841+
count = count + 1
842+
843+
# Check if the index is unique (appears exactly once)
844+
if count == 1:
845+
pos = convex.vert[convex.vertadr + idx]
846+
pos = convex.pos + convex.rot @ pos
847+
support = wp.dot(plane_pos - convex.vert[convex.vertadr + idx], n)
848+
dist = -support
849+
pos = pos - 0.5 * dist * plane.normal
850+
write_contact(
851+
nconmax_in,
852+
dist,
853+
pos,
854+
frame,
855+
margin,
856+
gap,
857+
condim,
858+
friction,
859+
solref,
860+
solreffriction,
861+
solimp,
862+
geoms,
863+
worldid,
864+
ncon_out,
865+
contact_dist_out,
866+
contact_pos_out,
867+
contact_frame_out,
868+
contact_includemargin_out,
869+
contact_friction_out,
870+
contact_solref_out,
871+
contact_solreffriction_out,
872+
contact_solimp_out,
873+
contact_dim_out,
874+
contact_geom_out,
875+
contact_worldid_out,
876+
)
877+
878+
727879
@wp.func
728880
def sphere_cylinder(
729881
# Data in:
@@ -1734,6 +1886,7 @@ def _primitive_narrowphase(
17341886
geom_gap: wp.array(dtype=float),
17351887
mesh_vertadr: wp.array(dtype=int),
17361888
mesh_vertnum: wp.array(dtype=int),
1889+
mesh_vert: wp.array(dtype=wp.vec3),
17371890
pair_dim: wp.array(dtype=int),
17381891
pair_solref: wp.array(dtype=wp.vec2),
17391892
pair_solreffriction: wp.array(dtype=wp.vec2),
@@ -1794,20 +1947,24 @@ def _primitive_narrowphase(
17941947
worldid = collision_worldid_in[tid]
17951948

17961949
geom1 = _geom(
1950+
geom_type,
17971951
geom_dataid,
17981952
geom_size,
17991953
mesh_vertadr,
18001954
mesh_vertnum,
1955+
mesh_vert,
18011956
geom_xpos_in,
18021957
geom_xmat_in,
18031958
worldid,
18041959
g1,
18051960
)
18061961
geom2 = _geom(
1962+
geom_type,
18071963
geom_dataid,
18081964
geom_size,
18091965
mesh_vertadr,
18101966
mesh_vertnum,
1967+
mesh_vert,
18111968
geom_xpos_in,
18121969
geom_xmat_in,
18131970
worldid,
@@ -1953,6 +2110,33 @@ def _primitive_narrowphase(
19532110
contact_geom_out,
19542111
contact_worldid_out,
19552112
)
2113+
elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.MESH.value):
2114+
plane_convex(
2115+
nconmax_in,
2116+
geom1,
2117+
geom2,
2118+
worldid,
2119+
margin,
2120+
gap,
2121+
condim,
2122+
friction,
2123+
solref,
2124+
solreffriction,
2125+
solimp,
2126+
geoms,
2127+
ncon_out,
2128+
contact_dist_out,
2129+
contact_pos_out,
2130+
contact_frame_out,
2131+
contact_includemargin_out,
2132+
contact_friction_out,
2133+
contact_solref_out,
2134+
contact_solreffriction_out,
2135+
contact_solimp_out,
2136+
contact_dim_out,
2137+
contact_geom_out,
2138+
contact_worldid_out,
2139+
)
19562140
elif type1 == int(GeomType.SPHERE.value) and type2 == int(GeomType.CAPSULE.value):
19572141
sphere_capsule(
19582142
nconmax_in,
@@ -2110,6 +2294,7 @@ def primitive_narrowphase(m: Model, d: Data):
21102294
m.geom_gap,
21112295
m.mesh_vertadr,
21122296
m.mesh_vertnum,
2297+
m.mesh_vert,
21132298
m.pair_dim,
21142299
m.pair_solref,
21152300
m.pair_solreffriction,

mujoco_warp/_src/io.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,11 +1108,21 @@ def get_data_into(
11081108
result.xfrc_applied[:] = d.xfrc_applied.numpy()[0]
11091109
result.eq_active[:] = d.eq_active.numpy()[0]
11101110

1111-
result.efc_D[:] = d.efc.D.numpy()[:nefc]
1112-
result.efc_pos[:] = d.efc.pos.numpy()[:nefc]
1113-
result.efc_aref[:] = d.efc.aref.numpy()[:nefc]
1114-
result.efc_force[:] = d.efc.force.numpy()[:nefc]
1115-
result.efc_margin[:] = d.efc.margin.numpy()[:nefc]
1111+
# TODO(team): set these efc_* fields after fix to _realloc_con_efc
1112+
# Safely copy only up to the minimum of the destination and source sizes
1113+
# n = min(result.efc_D.shape[0], d.efc.D.numpy()[:nefc].shape[0])
1114+
# result.efc_D[:n] = d.efc.D.numpy()[:nefc][:n]
1115+
# n_pos = min(result.efc_pos.shape[0], d.efc.pos.numpy()[:nefc].shape[0])
1116+
# result.efc_pos[:n_pos] = d.efc.pos.numpy()[:nefc][:n_pos]
1117+
1118+
# n_aref = min(result.efc_aref.shape[0], d.efc.aref.numpy()[:nefc].shape[0])
1119+
# result.efc_aref[:n_aref] = d.efc.aref.numpy()[:nefc][:n_aref]
1120+
1121+
# n_force = min(result.efc_force.shape[0], d.efc.force.numpy()[:nefc].shape[0])
1122+
# result.efc_force[:n_force] = d.efc.force.numpy()[:nefc][:n_force]
1123+
1124+
# n_margin = min(result.efc_margin.shape[0], d.efc.margin.numpy()[:nefc].shape[0])
1125+
# result.efc_margin[:n_margin] = d.efc.margin.numpy()[:nefc][:n_margin]
11161126

11171127
result.cacc[:] = d.cacc.numpy()[0]
11181128
result.cfrc_int[:] = d.cfrc_int.numpy()[0]

0 commit comments

Comments
 (0)