Skip to content

Feature/71 height fields #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions mujoco_warp/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import warp as wp

from .collision_hfield import get_hfield_prism_vertex
from .collision_primitive import Geom
from .collision_primitive import _geom
from .collision_primitive import contact_params
Expand Down Expand Up @@ -105,6 +106,15 @@ def _gjk_support_geom(
max_dist = dist
support_pt = vert
support_pt = geom.rot @ support_pt + geom.pos
elif geom_type == int(GeomType.HFIELD.value):
max_dist = float(FLOAT_MIN)
for i in range(6):
vert = get_hfield_prism_vertex(geom.hfprism, i)
dist = wp.dot(vert, local_dir)
if dist > max_dist:
max_dist = dist
support_pt = vert
support_pt = geom.rot @ support_pt + geom.pos

return wp.dot(support_pt, dir), support_pt

Expand All @@ -130,6 +140,12 @@ def _gjk_support(


_CONVEX_COLLISION_FUNC = {
(GeomType.HFIELD.value, GeomType.SPHERE.value),
(GeomType.HFIELD.value, GeomType.CAPSULE.value),
(GeomType.HFIELD.value, GeomType.ELLIPSOID.value),
(GeomType.HFIELD.value, GeomType.CYLINDER.value),
(GeomType.HFIELD.value, GeomType.BOX.value),
(GeomType.HFIELD.value, GeomType.MESH.value),
(GeomType.SPHERE.value, GeomType.ELLIPSOID.value),
(GeomType.SPHERE.value, GeomType.MESH.value),
(GeomType.CAPSULE.value, GeomType.CYLINDER.value),
Expand Down Expand Up @@ -763,8 +779,10 @@ def _gjk_epa_sparse(m: Model, d: Data):
if m.geom_type[g1] != geomtype1 or m.geom_type[g2] != geomtype2:
return

geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
index = d.collision_index[tid]

geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid], index)
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid], index)

margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])

Expand Down
48 changes: 40 additions & 8 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

from .collision_box import box_box_narrowphase
from .collision_convex import gjk_narrowphase
from .collision_hfield import get_hfield_overlap_range
from .collision_primitive import primitive_narrowphase
from .types import MJ_MAXVAL
from .types import MJ_MINVAL
from .types import Data
from .types import DisableBit
from .types import GeomType
from .types import Model
from .warp_util import event_scope

Expand Down Expand Up @@ -60,11 +62,7 @@ def _sphere_filter(m: Model, d: Data, geom1: int, geom2: int, worldid: int) -> b

@wp.func
def _add_geom_pair(m: Model, d: Data, geom1: int, geom2: int, worldid: int, nxnid: int):
pairid = wp.atomic_add(d.ncollision, 0, 1)

if pairid >= d.nconmax:
return

nxn_pairid = m.nxn_pairid[nxnid]
type1 = m.geom_type[geom1]
type2 = m.geom_type[geom2]

Expand All @@ -73,9 +71,43 @@ def _add_geom_pair(m: Model, d: Data, geom1: int, geom2: int, worldid: int, nxni
else:
pair = wp.vec2i(geom1, geom2)

d.collision_pair[pairid] = pair
d.collision_pairid[pairid] = m.nxn_pairid[nxnid]
d.collision_worldid[pairid] = worldid
# For height field, add a collision pair for every
# triangle that can potentially collide
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will probably need to check performance here.

Adding a collision pair for every vertex may lead to large number of collision pairs, which would slow down the code and requires large amount of memory.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each triangle, not each vertex. I create 2 pairs/triangles for each HF cell. If you don't have a HF with absurdly small cells, a body of an average size will collide with up to 4 cells - 8 triangles. Each pair is just 2 integers (plus I added an integer for tri index), and they are already pre-allocated, it's not like I allocate them when I add a pair. For the performance, for GPU it doesn't matter if it's 8 or 800 - they are all computed simultaneously. And, after all, what are the alternatives? Would like to hear them )

if type1 == int(GeomType.HFIELD.value) or type2 == int(GeomType.HFIELD.value):
hfield = geom1
other = geom2
if type1 != int(GeomType.HFIELD.value):
hfield = geom2
other = geom1

# Get min/max grid coordinates for overlap region
min_i, min_j, max_i, max_j = get_hfield_overlap_range(m, d, hfield, other, worldid)

# Get hfield dimensions for triangle index calculation
dataid = m.geom_dataid[hfield]
ncol = m.hfield_ncol[dataid]

# Loop through grid cells and add pairs for all triangles
for j in range(min_j, max_j + 1):
for i in range(min_i, max_i + 1):
for t in range(2):
pairid = wp.atomic_add(d.ncollision, 0, 1)
if pairid >= d.nconmax:
return

d.collision_pair[pairid] = pair
d.collision_index[pairid] = ((j * (ncol - 1)) + i) * 2 + t
d.collision_pairid[pairid] = nxn_pairid
d.collision_worldid[pairid] = worldid
else:
pairid = wp.atomic_add(d.ncollision, 0, 1)
if pairid >= d.nconmax:
return

d.collision_pair[pairid] = pair
d.collision_index[pairid] = -1
d.collision_pairid[pairid] = nxn_pairid
d.collision_worldid[pairid] = worldid


@wp.func
Expand Down
31 changes: 31 additions & 0 deletions mujoco_warp/_src/collision_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,37 @@ def test_collision(self, fixture):
if not allow_different_contact_count:
self.assertEqual(d.ncon.numpy()[0], mjd.ncon)

_HFIELD_FIXTURES = {
"hfield_box": """
<mujoco>
<asset>
<hfield name="terrain" nrow="2" ncol="2" size="1 1 0.1 0.1"
elevation="0 0
0 0"/>
</asset>
<worldbody>
<geom type="hfield" hfield="terrain" pos="0 0 0"/>
<body pos=".0 .0 .1">
<freejoint/>
<geom type="box" size=".1 .1 .11"/>
</body>
</worldbody>
</mujoco>
""",
}

@parameterized.parameters(_HFIELD_FIXTURES.keys())
def test_hfield_collision(self, fixture):
"""Tests hfield collision with different geometries."""
mjm, mjd, m, d = test_util.fixture(xml=self._HFIELD_FIXTURES[fixture])

mujoco.mj_collision(mjm, mjd)
mjwarp.collision(m, d)

self.assertEqual(
mjd.ncon > 0, d.ncon.numpy()[0] > 0, "If MJ collides, MJW should too"
)

def test_contact_exclude(self):
"""Tests contact exclude."""
mjm = mujoco.MjModel.from_xml_string("""
Expand Down
197 changes: 197 additions & 0 deletions mujoco_warp/_src/collision_hfield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import math
from typing import Any

import warp as wp

from .types import Data
from .types import GeomType
from .types import Model
from .warp_util import event_scope


@wp.func
def get_hfield_overlap_range(
m: Model, d: Data, hfield_geom: int, other_geom: int, worldid: int
):
"""Returns min/max grid coordinates of height field cells overlapped by other geom's bounds.

Args:
m: Model containing geometry data
d: Data containing current state
hfield_geom: Index of the height field geometry
other_geom: Index of the other geometry
worldid: Current world index

Returns:
min_i, min_j, max_i, max_j: Grid coordinate bounds
"""
# Get height field dimensions
dataid = m.geom_dataid[hfield_geom]
nrow = m.hfield_nrow[dataid]
ncol = m.hfield_ncol[dataid]
size = m.hfield_size[dataid] # (x, y, z_top, z_bottom)

# Get positions and transforms
hf_pos = d.geom_xpos[worldid, hfield_geom]
hf_mat = d.geom_xmat[worldid, hfield_geom]
other_pos = d.geom_xpos[worldid, other_geom]

# Transform other_pos to height field local space
rel_pos = other_pos - hf_pos
local_x = wp.dot(wp.vec3(hf_mat[0, 0], hf_mat[1, 0], hf_mat[2, 0]), rel_pos)
local_y = wp.dot(wp.vec3(hf_mat[0, 1], hf_mat[1, 1], hf_mat[2, 1]), rel_pos)
local_z = wp.dot(wp.vec3(hf_mat[0, 2], hf_mat[1, 2], hf_mat[2, 2]), rel_pos)
local_pos = wp.vec3(local_x, local_y, local_z)

# Get bounding radius of other geometry (including margin)
other_rbound = m.geom_rbound[other_geom]
other_margin = m.geom_margin[other_geom]
bound_radius = other_rbound + other_margin

# Calculate grid resolution
x_scale = 2.0 * size[0] / wp.float32(ncol - 1)
y_scale = 2.0 * size[1] / wp.float32(nrow - 1)

# Calculate min/max grid coordinates that could contain the object
min_i = wp.max(0, wp.int32((local_pos[0] - bound_radius + size[0]) / x_scale))
max_i = wp.min(
ncol - 2, wp.int32((local_pos[0] + bound_radius + size[0]) / x_scale) + 1
)
min_j = wp.max(0, wp.int32((local_pos[1] - bound_radius + size[1]) / y_scale))
max_j = wp.min(
nrow - 2, wp.int32((local_pos[1] + bound_radius + size[1]) / y_scale) + 1
)

return min_i, min_j, max_i, max_j


@wp.func
def get_hfield_triangle_prism(m: Model, hfieldid: int, tri_index: int) -> wp.mat33:
"""Returns the vertices of a triangular prism for a heightfield triangle.

Args:
m: Model containing geometry data
hfieldid: Index of the height field geometry
tri_index: Index of the triangle in the heightfield

Returns:
3x3 matrix containing the vertices of the triangular prism
"""
# See https://mujoco.readthedocs.io/en/stable/XMLreference.html#asset-hfield

# Get heightfield dimensions
dataid = m.geom_dataid[hfieldid]
if dataid < 0 or tri_index < 0:
return wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

nrow = m.hfield_nrow[dataid]
ncol = m.hfield_ncol[dataid]
size = m.hfield_size[dataid] # (x, y, z_top, z_bottom)

# Calculate which triangle in the grid
row = (tri_index // 2) // (ncol - 1)
col = (tri_index // 2) % (ncol - 1)

# Calculate vertices in 2D grid
x_scale = 2.0 * size[0] / wp.float32(ncol - 1)
y_scale = 2.0 * size[1] / wp.float32(nrow - 1)

# Grid coordinates (i, j) for triangle corners
i0 = col
j0 = row
i1 = i0 + 1
j1 = j0 + 1

# Convert grid coordinates to local space x, y coordinates
x0 = wp.float32(i0) * x_scale - size[0]
y0 = wp.float32(j0) * y_scale - size[1]
x1 = wp.float32(i1) * x_scale - size[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor:
I guess this should be more optimized
x1 = x0 + x_scale
y1 = y0 + y_scale

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the logic, but isn't it more confusing that it is now?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both are equivalent to me. But if you think the original version is better, it's fine for me.

y1 = wp.float32(j1) * y_scale - size[1]

# Get height values at corners from hfield_data
base_addr = m.hfield_adr[dataid]
z00 = m.hfield_data[base_addr + j0 * ncol + i0]
z01 = m.hfield_data[base_addr + j1 * ncol + i0]
z10 = m.hfield_data[base_addr + j0 * ncol + i1]
z11 = m.hfield_data[base_addr + j1 * ncol + i1]

# Scale heights from range [0, 1] to [0, z_top]
z_top = size[2]
z00 = z00 * z_top
z01 = z01 * z_top
z10 = z10 * z_top
z11 = z11 * z_top

# Set bottom z-value
z_bottom = -size[3]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why minus here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# Compress 6 prism vertices into 3x3 matrix
# See get_hfield_prism_vertex() for the details
return wp.mat33(
x0,
y0,
z00,
x1,
y1,
z11,
wp.where(tri_index % 2, 1.0, 0.0),
wp.where(tri_index % 2, z10, z01),
z_bottom,
)


@wp.func
def get_hfield_prism_vertex(prism: wp.mat33, vert_index: int) -> wp.vec3:
"""Extracts vertices from a compressed triangular prism representation.

The compression scheme stores a 6-vertex triangular prism using a 3x3 matrix:
- prism[0] = First vertex (x,y,z) - corner (i,j)
- prism[1] = Second vertex (x,y,z) - corner (i+1,j+1)
- prism[2,0] = Triangle type flag: 0 for even triangle (using corner (i,j+1)),
non-zero for odd triangle (using corner (i+1,j))
- prism[2,1] = Z-coordinate of the third vertex
- prism[2,2] = Z-coordinate used for all bottom vertices (common z)

In this way, we can reconstruct all 6 vertices of the prism by reusing
coordinates from the stored vertices.

Args:
prism: 3x3 compressed representation of a triangular prism
vert_index: Index of vertex to extract (0-5)

Returns:
The 3D coordinates of the requested vertex
"""
if vert_index == 0 or vert_index == 1:
return prism[vert_index] # First two vertices stored directly

if vert_index == 2: # Third vertex
if prism[2][0] == 0: # Even triangle (i,j+1)
return wp.vec3(prism[0][0], prism[1][1], prism[2][1])
else: # Odd triangle (i+1,j)
return wp.vec3(prism[1][0], prism[0][1], prism[2][1])

if vert_index == 3 or vert_index == 4: # Bottom vertices below 0 and 1
return wp.vec3(prism[vert_index - 3][0], prism[vert_index - 3][1], prism[2][2])

if vert_index == 5: # Bottom vertex below 2
if prism[2][0] == 0: # Even triangle
return wp.vec3(prism[0][0], prism[1][1], prism[2][2])
else: # Odd triangle
return wp.vec3(prism[1][0], prism[0][1], prism[2][2])
Loading