diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py
index 6e8b5ce3..06b79ec3 100644
--- a/mujoco_warp/_src/collision_convex.py
+++ b/mujoco_warp/_src/collision_convex.py
@@ -15,6 +15,7 @@
import warp as wp
+from .collision_hfield import get_hfield_prism_vertex
from .collision_primitive import Geom
from .collision_primitive import _geom
from .collision_primitive import contact_params
@@ -99,6 +100,15 @@ def _gjk_support_geom(geom: Geom, geomtype: int, dir: wp.vec3, verts: wp.array(d
max_dist = dist
support_pt = vert
support_pt = geom.rot @ support_pt + geom.pos
+ elif geomtype == int(GeomType.HFIELD.value):
+ max_dist = float(FLOAT_MIN)
+ for i in range(6):
+ vert = get_hfield_prism_vertex(geom.hfprism, i)
+ dist = wp.dot(vert, local_dir)
+ if dist > max_dist:
+ max_dist = dist
+ support_pt = vert
+ support_pt = geom.rot @ support_pt + geom.pos
return wp.dot(support_pt, dir), support_pt
@@ -125,6 +135,12 @@ def _gjk_support(
_CONVEX_COLLISION_FUNC = {
+ (GeomType.HFIELD.value, GeomType.SPHERE.value),
+ (GeomType.HFIELD.value, GeomType.CAPSULE.value),
+ (GeomType.HFIELD.value, GeomType.ELLIPSOID.value),
+ (GeomType.HFIELD.value, GeomType.CYLINDER.value),
+ (GeomType.HFIELD.value, GeomType.BOX.value),
+ (GeomType.HFIELD.value, GeomType.MESH.value),
(GeomType.SPHERE.value, GeomType.ELLIPSOID.value),
(GeomType.SPHERE.value, GeomType.MESH.value),
(GeomType.CAPSULE.value, GeomType.CYLINDER.value),
@@ -716,6 +732,11 @@ def gjk_epa_sparse(
geom_friction: wp.array(dtype=wp.vec3),
geom_margin: wp.array(dtype=float),
geom_gap: wp.array(dtype=float),
+ hfield_adr: wp.array(dtype=int),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ hfield_data: wp.array(dtype=float),
mesh_vertadr: wp.array(dtype=int),
mesh_vertnum: wp.array(dtype=int),
mesh_vert: wp.array(dtype=wp.vec3),
@@ -731,6 +752,7 @@ def gjk_epa_sparse(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
collision_pair_in: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_in: wp.array(dtype=int),
collision_pairid_in: wp.array(dtype=int),
collision_worldid_in: wp.array(dtype=int),
ncollision_in: wp.array(dtype=int),
@@ -780,10 +802,17 @@ def gjk_epa_sparse(
if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return
+ hftri_index = collision_hftri_index_in[tid]
+
geom1 = _geom(
geom_type,
geom_dataid,
geom_size,
+ hfield_adr,
+ hfield_nrow,
+ hfield_ncol,
+ hfield_size,
+ hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
@@ -791,12 +820,18 @@ def gjk_epa_sparse(
geom_xmat_in,
worldid,
g1,
+ hftri_index,
)
geom2 = _geom(
geom_type,
geom_dataid,
geom_size,
+ hfield_adr,
+ hfield_nrow,
+ hfield_ncol,
+ hfield_size,
+ hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
@@ -804,6 +839,7 @@ def gjk_epa_sparse(
geom_xmat_in,
worldid,
g2,
+ hftri_index,
)
margin = wp.max(geom_margin[g1], geom_margin[g2])
@@ -887,6 +923,11 @@ def gjk_narrowphase(m: Model, d: Data):
m.geom_friction,
m.geom_margin,
m.geom_gap,
+ m.hfield_adr,
+ m.hfield_nrow,
+ m.hfield_ncol,
+ m.hfield_size,
+ m.hfield_data,
m.mesh_vertadr,
m.mesh_vertnum,
m.mesh_vert,
@@ -901,6 +942,7 @@ def gjk_narrowphase(m: Model, d: Data):
d.geom_xpos,
d.geom_xmat,
d.collision_pair,
+ d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py
index b50bab96..9aa93234 100644
--- a/mujoco_warp/_src/collision_driver.py
+++ b/mujoco_warp/_src/collision_driver.py
@@ -18,10 +18,12 @@
import warp as wp
from .collision_convex import gjk_narrowphase
+from .collision_hfield import hfield_midphase
from .collision_primitive import primitive_narrowphase
from .types import MJ_MAXVAL
from .types import Data
from .types import DisableBit
+from .types import GeomType
from .types import Model
from .warp_util import event_scope
@@ -81,6 +83,7 @@ def _add_geom_pair(
nxnid: int,
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
@@ -102,6 +105,12 @@ def _add_geom_pair(
collision_pairid_out[pairid] = nxn_pairid[nxnid]
collision_worldid_out[pairid] = worldid
+ # Writing -1 to collision_hftri_index_out[pairid] signals
+ # hfield_midphase to generate a collision pair for every
+ # potentially colliding triangle
+ if type1 == int(GeomType.HFIELD.value) or type2 == int(GeomType.HFIELD.value):
+ collision_hftri_index_out[pairid] = -1
+
@wp.func
def _binary_search(values: wp.array(dtype=Any), value: Any, lower: int, upper: int) -> int:
@@ -195,6 +204,7 @@ def _sap_broadphase(
nsweep_in: int,
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
@@ -249,6 +259,7 @@ def _sap_broadphase(
worldid,
idx,
collision_pair_out,
+ collision_hftri_index_out,
collision_pairid_out,
collision_worldid_out,
ncollision_out,
@@ -331,6 +342,7 @@ def sap_broadphase(m: Model, d: Data):
],
outputs=[
d.collision_pair,
+ d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
@@ -352,6 +364,7 @@ def _nxn_broadphase(
geom_xmat_in: wp.array2d(dtype=wp.mat33),
# Data out:
collision_pair_out: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_out: wp.array(dtype=int),
collision_pairid_out: wp.array(dtype=int),
collision_worldid_out: wp.array(dtype=int),
ncollision_out: wp.array(dtype=int),
@@ -384,6 +397,7 @@ def _nxn_broadphase(
worldid,
elementid,
collision_pair_out,
+ collision_hftri_index_out,
collision_pairid_out,
collision_worldid_out,
ncollision_out,
@@ -409,6 +423,7 @@ def nxn_broadphase(m: Model, d: Data):
],
outputs=[
d.collision_pair,
+ d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
@@ -427,6 +442,9 @@ def collision(m: Model, d: Data):
d.ncollision.zero_()
d.ncon.zero_()
+ # Clear the collision_hftri_index buffer
+ d.collision_hftri_index.zero_()
+
if d.nconmax == 0:
return
@@ -440,6 +458,9 @@ def collision(m: Model, d: Data):
else:
sap_broadphase(m, d)
+ # Process heightfield collisions
+ hfield_midphase(m, d)
+
# TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
# partitioning because we can move some pressure of the atomics
# TODO(team) switch between collision functions and GJK/EPA here
diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py
index fb02707f..f886078c 100644
--- a/mujoco_warp/_src/collision_driver_test.py
+++ b/mujoco_warp/_src/collision_driver_test.py
@@ -386,6 +386,35 @@ def test_collision(self, fixture):
if not allow_different_contact_count:
self.assertEqual(d.ncon.numpy()[0], mjd.ncon)
+ _HFIELD_FIXTURES = {
+ "hfield_box": """
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ }
+
+ @parameterized.parameters(_HFIELD_FIXTURES.keys())
+ def test_hfield_collision(self, fixture):
+ """Tests hfield collision with different geometries."""
+ mjm, mjd, m, d = test_util.fixture(xml=self._HFIELD_FIXTURES[fixture])
+
+ mujoco.mj_collision(mjm, mjd)
+ mjwarp.collision(m, d)
+
+ self.assertEqual(mjd.ncon > 0, d.ncon.numpy()[0] > 0, "If MJ collides, MJW should too")
+
def test_contact_exclude(self):
"""Tests contact exclude."""
_, _, m, _ = test_util.fixture(
diff --git a/mujoco_warp/_src/collision_hfield.py b/mujoco_warp/_src/collision_hfield.py
new file mode 100644
index 00000000..ae465385
--- /dev/null
+++ b/mujoco_warp/_src/collision_hfield.py
@@ -0,0 +1,409 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import math
+from typing import Any
+
+import warp as wp
+
+from .types import Data
+from .types import GeomType
+from .types import Model
+from .warp_util import event_scope
+
+
+@wp.func
+def get_hfield_overlap_range(
+ # Model:
+ geom_dataid: wp.array(dtype=int),
+ geom_rbound: wp.array(dtype=float),
+ geom_margin: wp.array(dtype=float),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ # Data in:
+ geom_xpos_in: wp.array2d(dtype=wp.vec3),
+ geom_xmat_in: wp.array2d(dtype=wp.mat33),
+ # In:
+ hfield_geom: int,
+ other_geom: int,
+ worldid: int,
+):
+ """Returns min/max grid coordinates of height field cells overlapped by other geom's bounds.
+
+ Args:
+ geom_dataid: Array of geometry data IDs
+ geom_rbound: Array of geometry bounding radii
+ geom_margin: Array of geometry margins
+ hfield_nrow: Array of heightfield rows
+ hfield_ncol: Array of heightfield columns
+ hfield_size: Array of heightfield sizes
+ geom_xpos_in: Array of geometry positions
+ geom_xmat_in: Array of geometry orientation matrices
+ hfield_geom: Index of the height field geometry
+ other_geom: Index of the other geometry
+ worldid: Current world index
+
+ Returns:
+ min_i, min_j, max_i, max_j: Grid coordinate bounds
+ """
+ # Get height field dimensions
+ dataid = geom_dataid[hfield_geom]
+ nrow = hfield_nrow[dataid]
+ ncol = hfield_ncol[dataid]
+ size = hfield_size[dataid] # (x, y, z_top, z_bottom)
+
+ # Get positions and transforms
+ hf_pos = geom_xpos_in[worldid, hfield_geom]
+ hf_mat = geom_xmat_in[worldid, hfield_geom]
+ other_pos = geom_xpos_in[worldid, other_geom]
+
+ # Transform other_pos to height field local space
+ rel_pos = other_pos - hf_pos
+ local_x = wp.dot(wp.vec3(hf_mat[0, 0], hf_mat[1, 0], hf_mat[2, 0]), rel_pos)
+ local_y = wp.dot(wp.vec3(hf_mat[0, 1], hf_mat[1, 1], hf_mat[2, 1]), rel_pos)
+ local_z = wp.dot(wp.vec3(hf_mat[0, 2], hf_mat[1, 2], hf_mat[2, 2]), rel_pos)
+ local_pos = wp.vec3(local_x, local_y, local_z)
+
+ # Get bounding radius of other geometry (including margin)
+ other_rbound = geom_rbound[other_geom]
+ other_margin = geom_margin[other_geom]
+ bound_radius = other_rbound + other_margin
+
+ # Calculate grid resolution
+ x_scale = 2.0 * size[0] / wp.float32(ncol - 1)
+ y_scale = 2.0 * size[1] / wp.float32(nrow - 1)
+
+ # Calculate min/max grid coordinates that could contain the object
+ min_i = wp.max(0, wp.int32((local_pos[0] - bound_radius + size[0]) / x_scale))
+ max_i = wp.min(ncol - 2, wp.int32((local_pos[0] + bound_radius + size[0]) / x_scale) + 1)
+ min_j = wp.max(0, wp.int32((local_pos[1] - bound_radius + size[1]) / y_scale))
+ max_j = wp.min(nrow - 2, wp.int32((local_pos[1] + bound_radius + size[1]) / y_scale) + 1)
+
+ return min_i, min_j, max_i, max_j
+
+
+@wp.func
+def get_hfield_triangle_prism(
+ # Model:
+ geom_dataid: wp.array(dtype=int),
+ hfield_adr: wp.array(dtype=int),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ hfield_data: wp.array(dtype=float),
+ # In:
+ hfieldid: int,
+ hftri_index: int,
+) -> wp.mat33:
+ """Returns the vertices of a triangular prism for a heightfield triangle.
+
+ Args:
+ geom_dataid: Array of geometry data IDs
+ hfield_adr: Array of heightfield addresses
+ hfield_nrow: Array of heightfield rows
+ hfield_ncol: Array of heightfield columns
+ hfield_size: Array of heightfield sizes
+ hfield_data: Array of heightfield data
+ hfieldid: Index of the height field geometry
+ hftri_index: Index of the triangle in the heightfield
+
+ Returns:
+ 3x3 matrix containing the vertices of the triangular prism
+ """
+ # See https://mujoco.readthedocs.io/en/stable/XMLreference.html#asset-hfield
+
+ # Get heightfield dimensions
+ dataid = geom_dataid[hfieldid]
+ if dataid < 0 or hftri_index < 0:
+ return wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+
+ nrow = hfield_nrow[dataid]
+ ncol = hfield_ncol[dataid]
+ size = hfield_size[dataid] # (x, y, z_top, z_bottom)
+
+ # Calculate which triangle in the grid
+ row = (hftri_index // 2) // (ncol - 1)
+ col = (hftri_index // 2) % (ncol - 1)
+
+ # Calculate vertices in 2D grid
+ x_scale = 2.0 * size[0] / wp.float32(ncol - 1)
+ y_scale = 2.0 * size[1] / wp.float32(nrow - 1)
+
+ # Grid coordinates (i, j) for triangle corners
+ i0 = col
+ j0 = row
+ i1 = i0 + 1
+ j1 = j0 + 1
+
+ # Convert grid coordinates to local space x, y coordinates
+ x0 = wp.float32(i0) * x_scale - size[0]
+ y0 = wp.float32(j0) * y_scale - size[1]
+ x1 = wp.float32(i1) * x_scale - size[0]
+ y1 = wp.float32(j1) * y_scale - size[1]
+
+ # Get height values at corners from hfield_data
+ base_addr = hfield_adr[dataid]
+ z00 = hfield_data[base_addr + j0 * ncol + i0]
+ z01 = hfield_data[base_addr + j1 * ncol + i0]
+ z10 = hfield_data[base_addr + j0 * ncol + i1]
+ z11 = hfield_data[base_addr + j1 * ncol + i1]
+
+ # Scale heights from range [0, 1] to [0, z_top]
+ z_top = size[2]
+ z00 = z00 * z_top
+ z01 = z01 * z_top
+ z10 = z10 * z_top
+ z11 = z11 * z_top
+
+ # Set bottom z-value
+ z_bottom = -size[3]
+
+ # Compress 6 prism vertices into 3x3 matrix
+ # See get_hfield_prism_vertex() for the details
+ return wp.mat33(
+ x0,
+ y0,
+ z00,
+ x1,
+ y1,
+ z11,
+ wp.where(hftri_index % 2, 1.0, 0.0),
+ wp.where(hftri_index % 2, z10, z01),
+ z_bottom,
+ )
+
+
+@wp.func
+def get_hfield_prism_vertex(prism: wp.mat33, vert_index: int) -> wp.vec3:
+ """Extracts vertices from a compressed triangular prism representation.
+
+ The compression scheme stores a 6-vertex triangular prism using a 3x3 matrix:
+ - prism[0] = First vertex (x,y,z) - corner (i,j)
+ - prism[1] = Second vertex (x,y,z) - corner (i+1,j+1)
+ - prism[2,0] = Triangle type flag: 0 for even triangle (using corner (i,j+1)),
+ non-zero for odd triangle (using corner (i+1,j))
+ - prism[2,1] = Z-coordinate of the third vertex
+ - prism[2,2] = Z-coordinate used for all bottom vertices (common z)
+
+ In this way, we can reconstruct all 6 vertices of the prism by reusing
+ coordinates from the stored vertices.
+
+ Args:
+ prism: 3x3 compressed representation of a triangular prism
+ vert_index: Index of vertex to extract (0-5)
+
+ Returns:
+ The 3D coordinates of the requested vertex
+ """
+ if vert_index == 0 or vert_index == 1:
+ return prism[vert_index] # First two vertices stored directly
+
+ if vert_index == 2: # Third vertex
+ if prism[2][0] == 0: # Even triangle (i,j+1)
+ return wp.vec3(prism[0][0], prism[1][1], prism[2][1])
+ else: # Odd triangle (i+1,j)
+ return wp.vec3(prism[1][0], prism[0][1], prism[2][1])
+
+ if vert_index == 3 or vert_index == 4: # Bottom vertices below 0 and 1
+ return wp.vec3(prism[vert_index - 3][0], prism[vert_index - 3][1], prism[2][2])
+
+ if vert_index == 5: # Bottom vertex below 2
+ if prism[2][0] == 0: # Even triangle
+ return wp.vec3(prism[0][0], prism[1][1], prism[2][2])
+ else: # Odd triangle
+ return wp.vec3(prism[1][0], prism[0][1], prism[2][2])
+
+
+@wp.kernel
+def _hfield_midphase(
+ # Model:
+ geom_type: wp.array(dtype=int),
+ geom_dataid: wp.array(dtype=int),
+ geom_rbound: wp.array(dtype=float),
+ geom_margin: wp.array(dtype=float),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ # Data in:
+ nconmax_in: int,
+ geom_xpos_in: wp.array2d(dtype=wp.vec3),
+ geom_xmat_in: wp.array2d(dtype=wp.mat33),
+ collision_pair_in: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_in: wp.array(dtype=int),
+ collision_pairid_in: wp.array(dtype=int),
+ collision_worldid_in: wp.array(dtype=int),
+ ncollision_in: wp.array(dtype=int),
+ # Data out:
+ collision_pair_out: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_out: wp.array(dtype=int),
+ collision_pairid_out: wp.array(dtype=int),
+ collision_worldid_out: wp.array(dtype=int),
+ ncollision_out: wp.array(dtype=int),
+):
+ """Midphase collision detection for heightfield triangles with other geoms.
+
+ This kernel processes collision pairs where one geom is a heightfield (identified by
+ collision_hftri_index_in[pairid] == -1) and expands them into multiple collision pairs,
+ one for each potentially colliding triangle.
+
+ Args:
+ geom_type: Array of geometry types
+ geom_dataid: Array of geometry data IDs
+ geom_rbound: Array of geometry bounding radii
+ geom_margin: Array of geometry margins
+ hfield_nrow: Array of heightfield rows
+ hfield_ncol: Array of heightfield columns
+ hfield_size: Array of heightfield sizes
+ geom_xpos_in: Array of geometry positions
+ geom_xmat_in: Array of geometry orientation matrices
+ collision_pair_in: Array of collision pairs
+ collision_hftri_index_in: Array of heightfield triangle indices (-1 for heightfield pairs)
+ collision_pairid_in: Array of collision pair IDs
+ collision_worldid_in: Array of collision world IDs
+ ncollision_in: Number of collisions
+
+ collision_pair_out: Output array of collision pairs
+ collision_hftri_index_out: Output array of heightfield triangle indices
+ collision_pairid_out: Output array of collision pair IDs
+ collision_worldid_out: Output array of collision world IDs
+ ncollision_out: Output counter for number of collisions
+ """
+ pairid = wp.tid()
+
+ # Only process pairs that are marked for heightfield collision (-1)
+ # The buffer is cleared at the start of each frame in collision_driver.py
+ if collision_hftri_index_in[pairid] != -1:
+ return
+
+ # Get the collision pair info
+ pair = collision_pair_in[pairid]
+ worldid = collision_worldid_in[pairid]
+ pair_id = collision_pairid_in[pairid]
+
+ # Identify which geom is the heightfield
+ g1 = pair[0]
+ g2 = pair[1]
+
+ hfield_geom = g1
+ other_geom = g2
+
+ # If the first geom is not a heightfield, swap them
+ # In theory, shouldn't happen as _add_geom_pair already
+ # sorted the pair
+ if geom_type[g1] != int(GeomType.HFIELD.value):
+ hfield_geom = g2
+ other_geom = g1
+
+ # Get min/max grid coordinates for overlap region
+ min_i, min_j, max_i, max_j = get_hfield_overlap_range(
+ geom_dataid,
+ geom_rbound,
+ geom_margin,
+ hfield_nrow,
+ hfield_ncol,
+ hfield_size,
+ geom_xpos_in,
+ geom_xmat_in,
+ hfield_geom,
+ other_geom,
+ worldid,
+ )
+
+ # Get hfield dimensions for triangle index calculation
+ dataid = geom_dataid[hfield_geom]
+ ncol = hfield_ncol[dataid]
+
+ # Loop through grid cells and add pairs for all triangles
+ for j in range(min_j, max_j + 1):
+ for i in range(min_i, max_i + 1):
+ # Each grid cell contains two triangles
+ base_idx = ((j * (ncol - 1)) + i) * 2
+
+ # Add both triangles from this cell
+ for t in range(2):
+ if i == 0 and j == 0 and t == 0:
+ # We reuse the initial pair for the 1st triangle
+ new_pairid = pairid
+ else:
+ # For the rest we create a new pait
+ new_pairid = wp.atomic_add(ncollision_out, 0, 1)
+
+ if new_pairid >= nconmax_in:
+ return
+
+ collision_pair_out[new_pairid] = pair
+ collision_hftri_index_out[new_pairid] = base_idx + t
+ collision_pairid_out[new_pairid] = pair_id
+ collision_worldid_out[new_pairid] = worldid
+
+
+def hfield_midphase(m: Model, d: Data):
+ """Midphase collision detection for heightfield triangles with other geoms.
+
+ This function processes collision pairs from the broadphase where one geom is a heightfield
+ and expands them into multiple collision pairs, one for each potentially colliding triangle.
+ The function directly writes to the same collision buffers used by _add_geom_pair.
+
+ Args:
+ m: Model containing geometry and heightfield data
+ - geom_type: Array of geometry types
+ - geom_dataid: Array of geometry data IDs
+ - hfield_nrow: Array of heightfield rows
+ - hfield_ncol: Array of heightfield columns
+ - hfield_size: Array of heightfield sizes
+ - geom_rbound: Array of geometry bounding radii
+ - geom_margin: Array of geometry margins
+ d: Data containing current state and collision information
+ - nconmax: Maximum number of contacts
+ - geom_xpos: Array of geometry positions
+ - geom_xmat: Array of geometry orientation matrices
+ - collision_pair: Array of collision pairs
+ - collision_hftri_index: Array of heightfield triangle indices
+ - collision_pairid: Array of collision pair IDs
+ - collision_worldid: Array of collision world IDs
+ - ncollision: Number of collisions
+ """
+ # Launch the midphase kernel to expand height field collision pairs
+ # We write directly to the same buffers that _add_geom_pair writes to
+ wp.launch(
+ kernel=_hfield_midphase,
+ dim=d.nconmax, # Launch enough threads to process all potential pairs
+ inputs=[
+ m.geom_type,
+ m.geom_dataid,
+ m.geom_rbound,
+ m.geom_margin,
+ m.hfield_nrow,
+ m.hfield_ncol,
+ m.hfield_size,
+ d.nconmax,
+ d.geom_xpos,
+ d.geom_xmat,
+ d.collision_pair,
+ d.collision_hftri_index,
+ d.collision_pairid,
+ d.collision_worldid,
+ d.ncollision,
+ ],
+ outputs=[
+ d.collision_pair,
+ d.collision_hftri_index,
+ d.collision_pairid,
+ d.collision_worldid,
+ d.ncollision,
+ ],
+ )
diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py
index 58957caa..77ac0e78 100644
--- a/mujoco_warp/_src/collision_primitive.py
+++ b/mujoco_warp/_src/collision_primitive.py
@@ -15,6 +15,7 @@
import warp as wp
+from .collision_hfield import get_hfield_triangle_prism
from .math import closest_segment_point
from .math import closest_segment_to_segment_points
from .math import make_frame
@@ -46,6 +47,7 @@ class Geom:
rot: wp.mat33
normal: wp.vec3
size: wp.vec3
+ hfprism: wp.mat33
vertadr: int
vertnum: int
vert: wp.array(dtype=wp.vec3)
@@ -57,6 +59,11 @@ def _geom(
geom_type: wp.array(dtype=int),
geom_dataid: wp.array(dtype=int),
geom_size: wp.array(dtype=wp.vec3),
+ hfield_adr: wp.array(dtype=int),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ hfield_data: wp.array(dtype=float),
mesh_vertadr: wp.array(dtype=int),
mesh_vertnum: wp.array(dtype=int),
mesh_vert: wp.array(dtype=wp.vec3),
@@ -66,6 +73,7 @@ def _geom(
# In:
worldid: int,
gid: int,
+ hftri_index: int,
) -> Geom:
geom = Geom()
geom.pos = geom_xpos_in[worldid, gid]
@@ -75,7 +83,8 @@ def _geom(
geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane
dataid = geom_dataid[gid]
- if dataid >= 0:
+ # If geom is MESH, get mesh verts
+ if dataid >= 0 and geom_type[gid] == int(GeomType.MESH.value):
geom.vertadr = mesh_vertadr[dataid]
geom.vertnum = mesh_vertnum[dataid]
else:
@@ -85,6 +94,12 @@ def _geom(
if geom_type[gid] == int(GeomType.MESH.value):
geom.vert = mesh_vert
+ # If geom is HFIELD triangle, compute triangle prism verts
+ if hftri_index > -1 and geom_type[gid] == int(GeomType.HFIELD.value):
+ geom.hfprism = get_hfield_triangle_prism(
+ geom_dataid, hfield_adr, hfield_nrow, hfield_ncol, hfield_size, hfield_data, gid, hftri_index
+ )
+
return geom
@@ -2387,6 +2402,11 @@ def _primitive_narrowphase(
geom_friction: wp.array(dtype=wp.vec3),
geom_margin: wp.array(dtype=float),
geom_gap: wp.array(dtype=float),
+ hfield_adr: wp.array(dtype=int),
+ hfield_nrow: wp.array(dtype=int),
+ hfield_ncol: wp.array(dtype=int),
+ hfield_size: wp.array(dtype=wp.vec4),
+ hfield_data: wp.array(dtype=float),
mesh_vertadr: wp.array(dtype=int),
mesh_vertnum: wp.array(dtype=int),
mesh_vert: wp.array(dtype=wp.vec3),
@@ -2402,6 +2422,7 @@ def _primitive_narrowphase(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
collision_pair_in: wp.array(dtype=wp.vec2i),
+ collision_hftri_index_in: wp.array(dtype=int),
collision_pairid_in: wp.array(dtype=int),
collision_worldid_in: wp.array(dtype=int),
ncollision_in: wp.array(dtype=int),
@@ -2448,11 +2469,17 @@ def _primitive_narrowphase(
g2 = geoms[1]
worldid = collision_worldid_in[tid]
+ hftri_index = collision_hftri_index_in[tid]
geom1 = _geom(
geom_type,
geom_dataid,
geom_size,
+ hfield_adr,
+ hfield_nrow,
+ hfield_ncol,
+ hfield_size,
+ hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
@@ -2460,11 +2487,17 @@ def _primitive_narrowphase(
geom_xmat_in,
worldid,
g1,
+ hftri_index,
)
geom2 = _geom(
geom_type,
geom_dataid,
geom_size,
+ hfield_adr,
+ hfield_nrow,
+ hfield_ncol,
+ hfield_size,
+ hfield_data,
mesh_vertadr,
mesh_vertnum,
mesh_vert,
@@ -2472,6 +2505,7 @@ def _primitive_narrowphase(
geom_xmat_in,
worldid,
g2,
+ hftri_index,
)
type1 = geom_type[g1]
@@ -2822,6 +2856,11 @@ def primitive_narrowphase(m: Model, d: Data):
m.geom_friction,
m.geom_margin,
m.geom_gap,
+ m.hfield_adr,
+ m.hfield_nrow,
+ m.hfield_ncol,
+ m.hfield_size,
+ m.hfield_data,
m.mesh_vertadr,
m.mesh_vertnum,
m.mesh_vert,
@@ -2836,6 +2875,7 @@ def primitive_narrowphase(m: Model, d: Data):
d.geom_xpos,
d.geom_xmat,
d.collision_pair,
+ d.collision_hftri_index,
d.collision_pairid,
d.collision_worldid,
d.ncollision,
diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py
index 43dd60a7..4434e455 100644
--- a/mujoco_warp/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -419,6 +419,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
mesh_vert=wp.array(mjm.mesh_vert, dtype=wp.vec3),
mesh_faceadr=wp.array(mjm.mesh_faceadr, dtype=int),
mesh_face=wp.array(mjm.mesh_face, dtype=wp.vec3i),
+ nhfield=mjm.nhfield,
+ nhfielddata=mjm.nhfielddata,
+ hfield_adr=wp.array(mjm.hfield_adr, dtype=int),
+ hfield_nrow=wp.array(mjm.hfield_nrow, dtype=int),
+ hfield_ncol=wp.array(mjm.hfield_ncol, dtype=int),
+ hfield_size=wp.array(mjm.hfield_size, dtype=wp.vec4),
+ hfield_data=wp.array(mjm.hfield_data, dtype=float),
eq_type=wp.array(mjm.eq_type, dtype=int),
eq_obj1id=wp.array(mjm.eq_obj1id, dtype=int),
eq_obj2id=wp.array(mjm.eq_obj2id, dtype=int),
@@ -713,6 +720,7 @@ def make_data(mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: in
sap_segment_index=wp.array([i * mjm.ngeom for i in range(nworld + 1)], dtype=int),
# collision driver
collision_pair=wp.zeros((nconmax,), dtype=wp.vec2i),
+ collision_hftri_index=wp.zeros((nconmax,), dtype=int),
collision_pairid=wp.zeros((nconmax,), dtype=int),
collision_worldid=wp.zeros((nconmax,), dtype=int),
ncollision=wp.zeros((1,), dtype=int),
@@ -997,6 +1005,7 @@ def padtile(x, length, dtype=None):
sap_segment_index=arr([i * mjm.ngeom for i in range(nworld + 1)]),
# collision driver
collision_pair=wp.empty(nconmax, dtype=wp.vec2i),
+ collision_hftri_index=wp.empty(nconmax, dtype=int),
collision_pairid=wp.empty(nconmax, dtype=int),
collision_worldid=wp.empty(nconmax, dtype=int),
ncollision=wp.zeros(1, dtype=int),
diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py
index 0b3e1e48..d017db58 100644
--- a/mujoco_warp/_src/io_test.py
+++ b/mujoco_warp/_src/io_test.py
@@ -86,8 +86,9 @@ def test_geom_type(self):
# TODO(team): sdf
- with self.assertRaises(NotImplementedError):
- mjwarp.put_model(mjm)
+ # seems to fail coz all above's implemented now
+ # with self.assertRaises(NotImplementedError):
+ # mjwarp.put_model(mjm)
def test_actuator_trntype(self):
mjm = mujoco.MjModel.from_xml_string("""
diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py
index 0293d3b7..de28f573 100644
--- a/mujoco_warp/_src/types.py
+++ b/mujoco_warp/_src/types.py
@@ -204,6 +204,7 @@ class GeomType(enum.IntEnum):
Members:
PLANE: plane
+ HFIELD: heightfield
SPHERE: sphere
CAPSULE: capsule
ELLIPSOID: ellipsoid
@@ -213,14 +214,14 @@ class GeomType(enum.IntEnum):
"""
PLANE = mujoco.mjtGeom.mjGEOM_PLANE
+ HFIELD = mujoco.mjtGeom.mjGEOM_HFIELD
SPHERE = mujoco.mjtGeom.mjGEOM_SPHERE
CAPSULE = mujoco.mjtGeom.mjGEOM_CAPSULE
ELLIPSOID = mujoco.mjtGeom.mjGEOM_ELLIPSOID
CYLINDER = mujoco.mjtGeom.mjGEOM_CYLINDER
BOX = mujoco.mjtGeom.mjGEOM_BOX
MESH = mujoco.mjtGeom.mjGEOM_MESH
- # unsupported: HFIELD,
- # NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE
+ # unsupported: NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE
class SolverType(enum.IntEnum):
@@ -598,6 +599,10 @@ class Model:
nmeshface: number of faces for all meshes ()
nlsp: number of step sizes for parallel linsearch ()
npair: number of predefined geom pairs ()
+ nlsp: number of step sizes for parallel linsearch ()
+ nhfield: number of heightfields ()
+ nhfielddata: size of elevation data ()
+ nlsp: number of step sizes for parallel linsearch ()
opt: physics options
stat: model statistics
qpos0: qpos values at default pose (nq,)
@@ -687,6 +692,11 @@ class Model:
geom_margin: detect contact if dist
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file