Skip to content

Feature/71 height fields #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
44d09c9
Height fields #71
vreutskyy May 5, 2025
f23bc12
Merge branch 'main' into feature/71-height-fields
vreutskyy May 5, 2025
593b0e3
some clean up and refactoring
vreutskyy May 5, 2025
d1d403d
added a UT. disabled another UT
vreutskyy May 5, 2025
c6105d0
Merge branch 'main' into feature/71-height-fields
vreutskyy May 6, 2025
be69586
fixed import sections with isort
vreutskyy May 6, 2025
53a93ab
fixed import section with ruff
vreutskyy May 6, 2025
dbf5563
fixed formatting with ruff
vreutskyy May 6, 2025
96c03c1
fixed formatting with ruff
vreutskyy May 6, 2025
f4c7160
fixed formatting with ruff
vreutskyy May 6, 2025
65a3762
UT fix
vreutskyy May 6, 2025
437fd15
UT fix
vreutskyy May 6, 2025
df7b99d
formatting fix
vreutskyy May 6, 2025
77d7d02
Segmentation fault fix
vreutskyy May 6, 2025
cc37655
Small fixes
vreutskyy May 7, 2025
cefbf74
formatting
vreutskyy May 7, 2025
33f594d
small fixes
vreutskyy May 7, 2025
97d3586
small fixes
vreutskyy May 7, 2025
e530fad
small fixes
vreutskyy May 7, 2025
4a322cf
Merge branch 'google-deepmind:main' into feature/71-height-fields
vreutskyy May 7, 2025
bf3ae94
Merge branch 'main' into feature/71-height-fields
vreutskyy May 12, 2025
29d049d
re-implemented hfield collision after API change
vreutskyy May 13, 2025
3a7f6be
Merge branch 'main' into feature/71-height-fields
vreutskyy May 13, 2025
af0aa00
formatting
vreutskyy May 13, 2025
46346dd
formatting
vreutskyy May 13, 2025
71c6060
formatting
vreutskyy May 13, 2025
d6fae7d
a comment typo
vreutskyy May 14, 2025
941f63c
removing a useless condition
vreutskyy May 14, 2025
4477abf
Merge branch 'main' into feature/71-height-fields
vreutskyy May 15, 2025
40cc0b1
Merge branch 'main' into feature/71-height-fields
vreutskyy May 15, 2025
60a138c
Merge branch 'main' into feature/71-height-fields
vreutskyy May 19, 2025
1b7af22
Merge branch 'main' into feature/71-height-fields
vreutskyy May 19, 2025
b94a549
Merge branch 'main' into feature/71-height-fields
vreutskyy May 20, 2025
20cb228
small fixes
vreutskyy May 20, 2025
b7d6df1
a fix for GJK/EPA and the restored UT
vreutskyy May 20, 2025
509011c
added a TODO for capping max contacts
vreutskyy May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions mujoco_warp/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import warp as wp

from .collision_hfield import get_hfield_prism_vertex
from .collision_primitive import Geom
from .collision_primitive import _geom
from .collision_primitive import contact_params
Expand Down Expand Up @@ -99,6 +100,15 @@ def _gjk_support_geom(geom: Geom, geomtype: int, dir: wp.vec3, verts: wp.array(d
max_dist = dist
support_pt = vert
support_pt = geom.rot @ support_pt + geom.pos
elif geomtype == int(GeomType.HFIELD.value):
max_dist = float(FLOAT_MIN)
for i in range(6):
vert = get_hfield_prism_vertex(geom.hfprism, i)
dist = wp.dot(vert, local_dir)
if dist > max_dist:
max_dist = dist
support_pt = vert
support_pt = geom.rot @ support_pt + geom.pos

return wp.dot(support_pt, dir), support_pt

Expand All @@ -125,6 +135,12 @@ def _gjk_support(


_CONVEX_COLLISION_FUNC = {
(GeomType.HFIELD.value, GeomType.SPHERE.value),
(GeomType.HFIELD.value, GeomType.CAPSULE.value),
(GeomType.HFIELD.value, GeomType.ELLIPSOID.value),
(GeomType.HFIELD.value, GeomType.CYLINDER.value),
(GeomType.HFIELD.value, GeomType.BOX.value),
(GeomType.HFIELD.value, GeomType.MESH.value),
(GeomType.SPHERE.value, GeomType.ELLIPSOID.value),
(GeomType.SPHERE.value, GeomType.MESH.value),
(GeomType.CAPSULE.value, GeomType.CYLINDER.value),
Expand Down Expand Up @@ -716,6 +732,11 @@ def gjk_epa_sparse(
geom_friction: wp.array2d(dtype=wp.vec3),
geom_margin: wp.array2d(dtype=float),
geom_gap: wp.array2d(dtype=float),
hfield_adr: wp.array(dtype=int),
hfield_nrow: wp.array(dtype=int),
hfield_ncol: wp.array(dtype=int),
hfield_size: wp.array(dtype=wp.vec4),
hfield_data: wp.array(dtype=float),
mesh_vertadr: wp.array(dtype=int),
mesh_vertnum: wp.array(dtype=int),
mesh_vert: wp.array(dtype=wp.vec3),
Expand All @@ -731,6 +752,7 @@ def gjk_epa_sparse(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
collision_pair_in: wp.array(dtype=wp.vec2i),
collision_hftri_index_in: wp.array(dtype=int),
collision_pairid_in: wp.array(dtype=int),
collision_worldid_in: wp.array(dtype=int),
ncollision_in: wp.array(dtype=int),
Expand Down Expand Up @@ -781,30 +803,44 @@ def gjk_epa_sparse(
if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return

hftri_index = collision_hftri_index_in[tid]

geom1 = _geom(
geom_type,
geom_dataid,
geom_size[worldid],
geom_size,
hfield_adr,
hfield_nrow,
hfield_ncol,
hfield_size,
hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
geom_xpos_in,
geom_xmat_in,
worldid,
g1,
hftri_index,
)

geom2 = _geom(
geom_type,
geom_dataid,
geom_size[worldid],
geom_size,
hfield_adr,
hfield_nrow,
hfield_ncol,
hfield_size,
hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
geom_xpos_in,
geom_xmat_in,
worldid,
g2,
hftri_index,
)

margin = wp.max(geom_margin[worldid, g1], geom_margin[worldid, g2])
Expand Down Expand Up @@ -888,6 +924,11 @@ def gjk_narrowphase(m: Model, d: Data):
m.geom_friction,
m.geom_margin,
m.geom_gap,
m.hfield_adr,
m.hfield_nrow,
m.hfield_ncol,
m.hfield_size,
m.hfield_data,
m.mesh_vertadr,
m.mesh_vertnum,
m.mesh_vert,
Expand All @@ -902,6 +943,7 @@ def gjk_narrowphase(m: Model, d: Data):
d.geom_xpos,
d.geom_xmat,
d.collision_pair,
d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
Expand Down
22 changes: 22 additions & 0 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import warp as wp

from .collision_convex import gjk_narrowphase
from .collision_hfield import hfield_midphase
from .collision_primitive import primitive_narrowphase
from .types import MJ_MAXVAL
from .types import Data
from .types import DisableBit
from .types import GeomType
from .types import Model
from .warp_util import event_scope

Expand Down Expand Up @@ -81,6 +83,7 @@ def _add_geom_pair(
nxnid: int,
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
Expand All @@ -102,6 +105,12 @@ def _add_geom_pair(
collision_pairid_out[pairid] = nxn_pairid[nxnid]
collision_worldid_out[pairid] = worldid

# Writing -1 to collision_hftri_index_out[pairid] signals
# hfield_midphase to generate a collision pair for every
# potentially colliding triangle
if type1 == int(GeomType.HFIELD.value) or type2 == int(GeomType.HFIELD.value):
collision_hftri_index_out[pairid] = -1


@wp.func
def _binary_search(values: wp.array(dtype=Any), value: Any, lower: int, upper: int) -> int:
Expand Down Expand Up @@ -195,6 +204,7 @@ def _sap_broadphase(
nsweep_in: int,
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
Expand Down Expand Up @@ -248,6 +258,7 @@ def _sap_broadphase(
worldid,
idx,
collision_pair_out,
collision_hftri_index_out,
collision_pairid_out,
collision_worldid_out,
ncollision_out,
Expand Down Expand Up @@ -330,6 +341,7 @@ def sap_broadphase(m: Model, d: Data):
],
outputs=[
d.collision_pair,
d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
Expand All @@ -351,6 +363,7 @@ def _nxn_broadphase(
geom_xmat_in: wp.array2d(dtype=wp.mat33),
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
Expand Down Expand Up @@ -383,6 +396,7 @@ def _nxn_broadphase(
worldid,
elementid,
collision_pair_out,
collision_hftri_index_out,
collision_pairid_out,
collision_worldid_out,
ncollision_out,
Expand All @@ -408,6 +422,7 @@ def nxn_broadphase(m: Model, d: Data):
],
outputs=[
d.collision_pair,
d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
Expand All @@ -426,6 +441,9 @@ def collision(m: Model, d: Data):
d.ncollision.zero_()
d.ncon.zero_()

# Clear the collision_hftri_index buffer
d.collision_hftri_index.zero_()

if d.nconmax == 0:
return

Expand All @@ -439,6 +457,10 @@ def collision(m: Model, d: Data):
else:
sap_broadphase(m, d)

# Process heightfield collisions
if m.nhfield > 0:
hfield_midphase(m, d)

# TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
# partitioning because we can move some pressure of the atomics
# TODO(team) switch between collision functions and GJK/EPA here
Expand Down
29 changes: 29 additions & 0 deletions mujoco_warp/_src/collision_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,35 @@ def test_collision(self, fixture):
if not allow_different_contact_count:
self.assertEqual(d.ncon.numpy()[0], mjd.ncon)

_HFIELD_FIXTURES = {
"hfield_box": """
<mujoco>
<asset>
<hfield name="terrain" nrow="2" ncol="2" size="1 1 0.1 0.1"
elevation="0 0
0 0"/>
</asset>
<worldbody>
<geom type="hfield" hfield="terrain" pos="0 0 0"/>
<body pos=".0 .0 .1">
<freejoint/>
<geom type="box" size=".1 .1 .11"/>
</body>
</worldbody>
</mujoco>
""",
}

@parameterized.parameters(_HFIELD_FIXTURES.keys())
def test_hfield_collision(self, fixture):
"""Tests hfield collision with different geometries."""
mjm, mjd, m, d = test_util.fixture(xml=self._HFIELD_FIXTURES[fixture])

mujoco.mj_collision(mjm, mjd)
mjwarp.collision(m, d)

self.assertEqual(mjd.ncon > 0, d.ncon.numpy()[0] > 0, "If MJ collides, MJW should too")

def test_contact_exclude(self):
"""Tests contact exclude."""
_, _, m, _ = test_util.fixture(
Expand Down
Loading
Loading