diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py deleted file mode 100644 index 562cf6ff..00000000 --- a/mujoco_warp/_src/collision_box.py +++ /dev/null @@ -1,651 +0,0 @@ -# Copyright 2025 The Newton Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import math -from typing import Any - -import warp as wp - -from .collision_primitive import contact_params -from .collision_primitive import write_contact -from .math import make_frame -from .types import Data -from .types import GeomType -from .types import Model -from .types import vec5 - -BOX_BOX_BLOCK_DIM = 32 - - -_HUGE_VAL = 1e6 -_TINY_VAL = 1e-6 - - -class vec16b(wp.types.vector(length=16, dtype=wp.int8)): - pass - - -class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): - pass - - -class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): - pass - - -class mat16_3f(wp.types.matrix(shape=(16, 3), dtype=wp.float32)): - pass - - -Box = mat83f - - -@wp.func -def _argmin(a: Any) -> wp.int32: - amin = wp.int32(0) - vmin = wp.float32(a[0]) - for i in range(1, len(a)): - if a[i] < vmin: - amin = i - vmin = a[i] - return amin - - -@wp.func -def box_normals(i: int) -> wp.vec3: - direction = wp.where(i < 3, -1.0, 1.0) - mod = i % 3 - if mod == 0: - return wp.vec3(0.0, direction, 0.0) - if mod == 1: - return wp.vec3(0.0, 0.0, direction) - return wp.vec3(-direction, 0.0, 0.0) - - -@wp.func -def box(R: wp.mat33, t: wp.vec3, size: wp.vec3) -> Box: - """Get a transformed box""" - x = size[0] - y = size[1] - z = size[2] - m = Box() - for i in range(8): - ix = wp.where(i & 4, x, -x) - iy = wp.where(i & 2, y, -y) - iz = wp.where(i & 1, z, -z) - m[i] = R @ wp.vec3(ix, iy, iz) + t - return m - - -@wp.func -def box_face_verts(box: Box, idx: int) -> mat43f: - """Get the quad corresponding to a box face""" - if idx == 0: - verts = wp.vec4i(0, 4, 5, 1) - if idx == 1: - verts = wp.vec4i(0, 2, 6, 4) - if idx == 2: - verts = wp.vec4i(6, 7, 5, 4) - if idx == 3: - verts = wp.vec4i(2, 3, 7, 6) - if idx == 4: - verts = wp.vec4i(1, 5, 7, 3) - if idx == 5: - verts = wp.vec4i(0, 1, 3, 2) - - m = mat43f() - for i in range(4): - m[i] = box[verts[i]] - return m - - -@wp.func -def get_box_axis(axis_idx: int, R: wp.mat33): - """Get the axis at index axis_idx. - R: rotation matrix from a to b - Axes 0-12 are face normals of boxes a & b - Axes 12-21 are edge cross products.""" - if axis_idx < 6: # a faces - axis = R @ wp.vec3(box_normals(axis_idx)) - is_degenerate = False - elif axis_idx < 12: # b faces - axis = wp.vec3(box_normals(axis_idx - 6)) - is_degenerate = False - else: # edges cross products - assert axis_idx < 21 - edges = axis_idx - 12 - axis_a, axis_b = edges / 3, edges % 3 - edge_a = wp.transpose(R)[axis_a] - if axis_b == 0: - axis = wp.vec3(0.0, -edge_a[2], edge_a[1]) - elif axis_b == 1: - axis = wp.vec3(edge_a[2], 0.0, -edge_a[0]) - else: - axis = wp.vec3(-edge_a[1], edge_a[0], 0.0) - is_degenerate = wp.length_sq(axis) < _TINY_VAL - return wp.normalize(axis), is_degenerate - - -@wp.func -def get_box_axis_support(axis: wp.vec3, degenerate_axis: bool, a: Box, b: Box): - """Get the overlap (or separating distance if negative) along `axis`, and the sign.""" - axis_d = wp.vec3d(axis) - support_a_max, support_b_max = wp.float32(-_HUGE_VAL), wp.float32(-_HUGE_VAL) - support_a_min, support_b_min = wp.float32(_HUGE_VAL), wp.float32(_HUGE_VAL) - for i in range(8): - vert_a = wp.vec3d(a[i]) - vert_b = wp.vec3d(b[i]) - proj_a = wp.float32(wp.dot(vert_a, axis_d)) - proj_b = wp.float32(wp.dot(vert_b, axis_d)) - support_a_max = wp.max(support_a_max, proj_a) - support_b_max = wp.max(support_b_max, proj_b) - support_a_min = wp.min(support_a_min, proj_a) - support_b_min = wp.min(support_b_min, proj_b) - dist1 = support_a_max - support_b_min - dist2 = support_b_max - support_a_min - dist = wp.where(degenerate_axis, _HUGE_VAL, wp.min(dist1, dist2)) - sign = wp.where(dist1 > dist2, -1, 1) - return dist, sign - - -@wp.struct -class AxisSupport: - best_dist: wp.float32 - best_sign: wp.int8 - best_idx: wp.int8 - - -@wp.func -def reduce_axis_support(a: AxisSupport, b: AxisSupport): - return wp.where(a.best_dist > b.best_dist, b, a) - - -@wp.func -def face_axis_alignment(a: wp.vec3, R: wp.mat33) -> wp.int32: - """Find the box faces most aligned with the axis `a`""" - max_dot = wp.float32(0.0) - max_idx = wp.int32(0) - for i in range(6): - d = wp.dot(R @ box_normals(i), a) - if d > max_dot: - max_dot = d - max_idx = i - return max_idx - - -@wp.kernel(enable_backward=False) -def _box_box( - # Model: - geom_type: wp.array(dtype=int), - geom_condim: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), - pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), - # Data in: - nconmax_in: int, - geom_xpos_in: wp.array2d(dtype=wp.vec3), - geom_xmat_in: wp.array2d(dtype=wp.mat33), - collision_pair_in: wp.array(dtype=wp.vec2i), - collision_pairid_in: wp.array(dtype=int), - collision_worldid_in: wp.array(dtype=int), - ncollision_in: wp.array(dtype=int), - # In: - num_kernels_in: int, - # Data out: - ncon_out: wp.array(dtype=int), - contact_dist_out: wp.array(dtype=float), - contact_pos_out: wp.array(dtype=wp.vec3), - contact_frame_out: wp.array(dtype=wp.mat33), - contact_includemargin_out: wp.array(dtype=float), - contact_dim_out: wp.array(dtype=int), - contact_friction_out: wp.array(dtype=vec5), - contact_solref_out: wp.array(dtype=wp.vec2), - contact_solreffriction_out: wp.array(dtype=wp.vec2), - contact_solimp_out: wp.array(dtype=vec5), - contact_geom_out: wp.array(dtype=wp.vec2i), - contact_worldid_out: wp.array(dtype=int), -): - """Calculates contacts between pairs of boxes.""" - tid, axis_idx = wp.tid() - - for bp_idx in range(tid, min(ncollision_in[0], nconmax_in), num_kernels_in): - geoms = collision_pair_in[bp_idx] - - ga, gb = geoms[0], geoms[1] - - if geom_type[ga] != int(GeomType.BOX.value) or geom_type[gb] != int(GeomType.BOX.value): - continue - - worldid = collision_worldid_in[bp_idx] - - geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( - geom_condim, - geom_priority, - geom_solmix, - geom_solref, - geom_solimp, - geom_friction, - geom_margin, - geom_gap, - pair_dim, - pair_solref, - pair_solreffriction, - pair_solimp, - pair_margin, - pair_gap, - pair_friction, - collision_pair_in, - collision_pairid_in, - tid, - ) - - # transformations - a_pos, b_pos = geom_xpos_in[worldid, ga], geom_xpos_in[worldid, gb] - a_mat, b_mat = geom_xmat_in[worldid, ga], geom_xmat_in[worldid, gb] - b_mat_inv = wp.transpose(b_mat) - trans_atob = b_mat_inv @ (a_pos - b_pos) - rot_atob = b_mat_inv @ a_mat - - a_size = geom_size[ga] - b_size = geom_size[gb] - a = box(rot_atob, trans_atob, a_size) - b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) - - # box-box implementation - - # Inlined def collision_axis_tiled( a: Box, b: Box, R: wp.mat33, axis_idx: wp.int32,): - # Finds the axis of minimum separation. - # a: Box a vertices, in frame b - # b: Box b vertices, in frame b - # R: rotation matrix from a to b - # Returns: - # best_axis: vec3 - # best_sign: int32 - # best_idx: int32 - R = rot_atob - - # launch tiled with block_dim=21 - if axis_idx > 20: - continue - - axis, degenerate_axis = get_box_axis(axis_idx, R) - axis_dist, axis_sign = get_box_axis_support(axis, degenerate_axis, a, b) - - supports = wp.tile(AxisSupport(axis_dist, wp.int8(axis_sign), wp.int8(axis_idx))) - - face_supports = wp.tile_view(supports, offset=(0,), shape=(12,)) - edge_supports = wp.tile_view(supports, offset=(12,), shape=(9,)) - - face_supports_red = wp.tile_reduce(reduce_axis_support, face_supports) - edge_supports_red = wp.tile_reduce(reduce_axis_support, edge_supports) - - face = face_supports_red[0] - edge = edge_supports_red[0] - - if axis_idx > 0: # single thread - continue - - # choose the best separating axis - face_axis, _ = get_box_axis(wp.int32(face.best_idx), R) - best_axis = wp.vec3(face_axis) - best_sign = wp.int32(face.best_sign) - best_idx = wp.int32(face.best_idx) - best_dist = wp.float32(face.best_dist) - - if edge.best_dist < face.best_dist: - edge_axis, _ = get_box_axis(wp.int32(edge.best_idx), R) - if wp.abs(wp.dot(face_axis, edge_axis)) < 0.99: - best_axis = edge_axis - best_sign = wp.int32(edge.best_sign) - best_idx = wp.int32(edge.best_idx) - best_dist = wp.float32(edge.best_dist) - # end inlined collision_axis_tiled - - # if axis_idx != 0: - # continue - if best_dist < 0: - continue - - # get the (reference) face most aligned with the separating axis - a_max = face_axis_alignment(best_axis, rot_atob) - b_max = face_axis_alignment(best_axis, wp.identity(3, wp.float32)) - - sep_axis = wp.float32(best_sign) * best_axis - - if best_sign > 0: - b_min = (b_max + 3) % 6 - dist, pos = _create_contact_manifold( - box_face_verts(a, a_max), - rot_atob @ box_normals(a_max), - box_face_verts(b, b_min), - box_normals(b_min), - ) - else: - a_min = (a_max + 3) % 6 - dist, pos = _create_contact_manifold( - box_face_verts(b, b_max), - box_normals(b_max), - box_face_verts(a, a_min), - rot_atob @ box_normals(a_min), - ) - - # For edge contacts, we use the clipped face point, mainly for performance - # reasons. For small penetration, the clipped face point is roughly the edge - # contact point. - if best_idx > 11: # is_edge_contact - idx = _argmin(dist) - dist = wp.vec4f(dist[idx], 1.0, 1.0, 1.0) - for i in range(4): - pos[i] = pos[idx] - - margin = wp.max(geom_margin[ga], geom_margin[gb]) - for i in range(4): - pos_glob = b_mat @ pos[i] + b_pos - n_glob = b_mat @ sep_axis - - write_contact( - nconmax_in, - dist[i], - pos_glob, - make_frame(n_glob), - margin, - gap, - condim, - friction, - solref, - solreffriction, - solimp, - geoms, - worldid, - ncon_out, - contact_dist_out, - contact_pos_out, - contact_frame_out, - contact_includemargin_out, - contact_friction_out, - contact_solref_out, - contact_solreffriction_out, - contact_solimp_out, - contact_dim_out, - contact_geom_out, - contact_worldid_out, - ) - - -@wp.func -def _closest_segment_point_plane(a: wp.vec3, b: wp.vec3, p0: wp.vec3, plane_normal: wp.vec3) -> wp.vec3: - """Gets the closest point between a line segment and a plane. - - Args: - a: first line segment point - b: second line segment point - p0: point on plane - plane_normal: plane normal - - Returns: - closest point between the line segment and the plane - """ - # Parametrize a line segment as S(t) = a + t * (b - a), plug it into the plane - # equation dot(n, S(t)) - d = 0, then solve for t to get the line-plane - # intersection. We then clip t to be in [0, 1] to be on the line segment. - n = plane_normal - d = wp.dot(p0, n) # shortest distance from origin to plane - denom = wp.dot(n, (b - a)) - t = (d - wp.dot(n, a)) / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) - t = wp.clamp(t, 0.0, 1.0) - segment_point = a + t * (b - a) - - return segment_point - - -@wp.func -def _project_poly_onto_plane(poly: Any, poly_n: wp.vec3, plane_n: wp.vec3, plane_pt: wp.vec3): - """Projects poly1 onto the poly2 plane along poly2's normal.""" - d = wp.dot(plane_pt, plane_n) - denom = wp.dot(poly_n, plane_n) - qn_scaled = poly_n / (denom + wp.where(denom == 0.0, _TINY_VAL, 0.0)) - - for i in range(len(poly)): - poly[i] = poly[i] + (d - wp.dot(poly[i], plane_n)) * qn_scaled - return poly - - -@wp.func -def _clip_edge_to_quad(subject_poly: mat43f, clipping_poly: mat43f, clipping_normal: wp.vec3): - p0 = mat43f() - p1 = mat43f() - mask = wp.vec4b() - for edge_idx in range(4): - subject_p0 = subject_poly[(edge_idx + 3) % 4] - subject_p1 = subject_poly[edge_idx] - - any_both_in_front = wp.int32(0) - clipped0_dist_max = wp.float32(-_HUGE_VAL) - clipped1_dist_max = wp.float32(-_HUGE_VAL) - clipped_p0_distmax = wp.vec3(0.0) - clipped_p1_distmax = wp.vec3(0.0) - - for clipping_edge_idx in range(4): - clipping_p0 = clipping_poly[(clipping_edge_idx + 3) % 4] - clipping_p1 = clipping_poly[clipping_edge_idx] - edge_normal = wp.cross(clipping_p1 - clipping_p0, clipping_normal) - - p0_in_front = wp.dot(subject_p0 - clipping_p0, edge_normal) > _TINY_VAL - p1_in_front = wp.dot(subject_p1 - clipping_p0, edge_normal) > _TINY_VAL - candidate_clipped_p = _closest_segment_point_plane(subject_p0, subject_p1, clipping_p1, edge_normal) - clipped_p0 = wp.where(p0_in_front, candidate_clipped_p, subject_p0) - clipped_p1 = wp.where(p1_in_front, candidate_clipped_p, subject_p1) - clipped_dist_p0 = wp.dot(clipped_p0 - subject_p0, subject_p1 - subject_p0) - clipped_dist_p1 = wp.dot(clipped_p1 - subject_p1, subject_p0 - subject_p1) - any_both_in_front |= wp.int32(p0_in_front and p1_in_front) - - if clipped_dist_p0 > clipped0_dist_max: - clipped0_dist_max = clipped_dist_p0 - clipped_p0_distmax = clipped_p0 - - if clipped_dist_p1 > clipped1_dist_max: - clipped1_dist_max = clipped_dist_p1 - clipped_p1_distmax = clipped_p1 - new_p0 = wp.where(any_both_in_front, subject_p0, clipped_p0_distmax) - new_p1 = wp.where(any_both_in_front, subject_p1, clipped_p1_distmax) - - mask_val = wp.int8( - wp.where( - wp.dot(subject_p0 - subject_p1, new_p0 - new_p1) < 0, - 0, - wp.int32(not any_both_in_front), - ) - ) - - p0[edge_idx] = new_p0 - p1[edge_idx] = new_p1 - mask[edge_idx] = mask_val - return p0, p1, mask - - -@wp.func -def _clip_quad(subject_quad: mat43f, subject_normal: wp.vec3, clipping_quad: mat43f, clipping_normal: wp.vec3): - """Clips a subject quad against a clipping quad. - Serial implementation. - """ - - subject_clipped_p0, subject_clipped_p1, subject_mask = _clip_edge_to_quad(subject_quad, clipping_quad, clipping_normal) - clipping_proj = _project_poly_onto_plane(clipping_quad, clipping_normal, subject_normal, subject_quad[0]) - clipping_clipped_p0, clipping_clipped_p1, clipping_mask = _clip_edge_to_quad(clipping_proj, subject_quad, subject_normal) - - clipped = mat16_3f() - mask = vec16b() - for i in range(4): - clipped[i] = subject_clipped_p0[i] - clipped[i + 4] = clipping_clipped_p0[i] - clipped[i + 8] = subject_clipped_p1[i] - clipped[i + 12] = clipping_clipped_p1[i] - mask[i] = subject_mask[i] - mask[i + 4] = clipping_mask[i] - mask[i + 8] = subject_mask[i] - mask[i + 8 + 4] = clipping_mask[i] - - return clipped, mask - - -# TODO(ca): tiling variant -@wp.func -def _manifold_points(poly: Any, mask: Any, clipping_norm: wp.vec3) -> wp.vec4b: - """Chooses four points on the polygon with approximately maximal area. Return the indices""" - n = len(poly) - - a_idx = wp.int32(0) - a_mask = wp.int8(mask[0]) - for i in range(n): - if mask[i] >= a_mask: - a_idx = i - a_mask = mask[i] - a = poly[a_idx] - - b_idx = wp.int32(0) - b_dist = wp.float32(-_HUGE_VAL) - for i in range(n): - dist = wp.length_sq(poly[i] - a) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist >= b_dist: - b_idx = i - b_dist = dist - b = poly[b_idx] - - ab = wp.cross(clipping_norm, a - b) - - c_idx = wp.int32(0) - c_dist = wp.float32(-_HUGE_VAL) - for i in range(n): - ap = a - poly[i] - dist = wp.abs(wp.dot(ap, ab)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist >= c_dist: - c_idx = i - c_dist = dist - c = poly[c_idx] - - ac = wp.cross(clipping_norm, a - c) - bc = wp.cross(clipping_norm, b - c) - - d_idx = wp.int32(0) - d_dist = wp.float32(-2.0 * _HUGE_VAL) - for i in range(n): - ap = a - poly[i] - dist_ap = wp.abs(wp.dot(ap, ac)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - bp = b - poly[i] - dist_bp = wp.abs(wp.dot(bp, bc)) + wp.where(mask[i], 0.0, -_HUGE_VAL) - if dist_ap + dist_bp >= d_dist: - d_idx = i - d_dist = dist_ap + dist_bp - d = poly[d_idx] - return wp.vec4b(wp.int8(a_idx), wp.int8(b_idx), wp.int8(c_idx), wp.int8(d_idx)) - - -@wp.func -def _create_contact_manifold(clipping_quad: mat43f, clipping_normal: wp.vec3, subject_quad: mat43f, subject_normal: wp.vec3): - # Clip the subject (incident) face onto the clipping (reference) face. - # The incident points are clipped points on the subject polygon. - incident, mask = _clip_quad(subject_quad, subject_normal, clipping_quad, clipping_normal) - - clipping_normal_neg = -clipping_normal - d = wp.dot(clipping_quad[0], clipping_normal_neg) + _TINY_VAL - - for i in range(16): - if wp.dot(incident[i], clipping_normal_neg) < d: - mask[i] = wp.int8(0) - - ref = _project_poly_onto_plane(incident, clipping_normal, clipping_normal, clipping_quad[0]) - - # Choose four contact points. - best = _manifold_points(ref, mask, clipping_normal) - contact_pts = mat43f() - dist = wp.vec4f() - - for i in range(4): - idx = wp.int32(best[i]) - contact_pt = ref[idx] - contact_pts[i] = contact_pt - penetration_dir = incident[idx] - contact_pt - penetration = wp.dot(penetration_dir, clipping_normal) - dist[i] = wp.where(mask[idx], penetration, 1.0) - - return dist, contact_pts - - -def box_box_narrowphase( - m: Model, - d: Data, -): - """Calculates contacts between pairs of boxes.""" - kernel_ratio = 16 - nthread = math.ceil(d.nconmax / kernel_ratio) # parallel threads excluding tile dim - wp.launch_tiled( - kernel=_box_box, - dim=nthread, - inputs=[ - m.geom_type, - m.geom_condim, - m.geom_priority, - m.geom_solmix, - m.geom_solref, - m.geom_solimp, - m.geom_size, - m.geom_friction, - m.geom_margin, - m.geom_gap, - m.pair_dim, - m.pair_solref, - m.pair_solreffriction, - m.pair_solimp, - m.pair_margin, - m.pair_gap, - m.pair_friction, - d.nconmax, - d.geom_xpos, - d.geom_xmat, - d.collision_pair, - d.collision_pairid, - d.collision_worldid, - d.ncollision, - nthread, - ], - outputs=[ - d.ncon, - d.contact.dist, - d.contact.pos, - d.contact.frame, - d.contact.includemargin, - d.contact.dim, - d.contact.friction, - d.contact.solref, - d.contact.solreffriction, - d.contact.solimp, - d.contact.geom, - d.contact.worldid, - ], - block_dim=BOX_BOX_BLOCK_DIM, - ) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 799051cc..b50bab96 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -17,7 +17,6 @@ import warp as wp -from .collision_box import box_box_narrowphase from .collision_convex import gjk_narrowphase from .collision_primitive import primitive_narrowphase from .types import MJ_MAXVAL @@ -446,4 +445,3 @@ def collision(m: Model, d: Data): # TODO(team) switch between collision functions and GJK/EPA here gjk_narrowphase(m, d) primitive_narrowphase(m, d) - box_box_narrowphase(m, d) diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py index 8327dc15..dfece796 100644 --- a/mujoco_warp/_src/collision_driver_test.py +++ b/mujoco_warp/_src/collision_driver_test.py @@ -39,6 +39,58 @@ class CollisionTest(parameterized.TestCase): """, + "box_box_vf": """ + + + + + + + + + + + + """, + "box_box_vf_flat": """ + + + + + + + + + + + + """, + "box_box_ee": """ + + + + + + + + + + + + """, + "box_box_ee_deep": """ + + + + + + + + + + + + """, "plane_sphere": """ diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 0bffc3a4..692feed7 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -28,6 +28,18 @@ wp.set_module_options({"enable_backward": False}) +class vec8f(wp.types.vector(length=8, dtype=wp.float32)): + pass + + +class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): + pass + + +class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): + pass + + @wp.struct class Geom: pos: wp.vec3 @@ -1718,6 +1730,497 @@ def capsule_box( ) +@wp.func +def _compute_rotmore(face_idx: int) -> wp.mat33: + rotmore = wp.mat33(0.0) + + if face_idx == 0: + rotmore[0, 2] = -1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 0] = +1.0 + elif face_idx == 1: + rotmore[0, 0] = +1.0 + rotmore[1, 2] = -1.0 + rotmore[2, 1] = +1.0 + elif face_idx == 2: + rotmore[0, 0] = +1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 2] = +1.0 + elif face_idx == 3: + rotmore[0, 2] = +1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 0] = -1.0 + elif face_idx == 4: + rotmore[0, 0] = +1.0 + rotmore[1, 2] = +1.0 + rotmore[2, 1] = -1.0 + elif face_idx == 5: + rotmore[0, 0] = -1.0 + rotmore[1, 1] = +1.0 + rotmore[2, 2] = -1.0 + + return rotmore + + +@wp.func +def box_box( + # Data in: + nconmax_in: int, + # In: + box1: Geom, + box2: Geom, + worldid: int, + margin: float, + gap: float, + condim: int, + friction: vec5, + solref: wp.vec2f, + solreffriction: wp.vec2f, + solimp: vec5, + geoms: wp.vec2i, + # Data out: + ncon_out: wp.array(dtype=int), + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), + contact_includemargin_out: wp.array(dtype=float), + contact_friction_out: wp.array(dtype=vec5), + contact_solref_out: wp.array(dtype=wp.vec2), + contact_solreffriction_out: wp.array(dtype=wp.vec2), + contact_solimp_out: wp.array(dtype=vec5), + contact_dim_out: wp.array(dtype=int), + contact_geom_out: wp.array(dtype=wp.vec2i), + contact_worldid_out: wp.array(dtype=int), +): + # Compute transforms between box's frames + + pos21 = wp.transpose(box1.rot) @ (box2.pos - box1.pos) + pos12 = wp.transpose(box2.rot) @ (box1.pos - box2.pos) + + rot21 = wp.transpose(box1.rot) @ box2.rot + rot12 = wp.transpose(rot21) + + rot21abs = wp.matrix_from_rows(wp.abs(rot21[0]), wp.abs(rot21[1]), wp.abs(rot21[2])) + rot12abs = wp.transpose(rot21abs) + + plen2 = rot21abs @ box2.size + plen1 = rot12abs @ box1.size + + # Compute axis of maximum separation + s_sum_3 = 3.0 * (box1.size + box2.size) + separation = wp.float32(margin + s_sum_3[0] + s_sum_3[1] + s_sum_3[2]) + axis_code = wp.int32(-1) + + # First test: consider boxes' face normals + for i in range(3): + c1 = -wp.abs(pos21[i]) + box1.size[i] + plen2[i] + + c2 = -wp.abs(pos12[i]) + box2.size[i] + plen1[i] + + if c1 < -margin or c2 < -margin: + return + + if c1 < separation: + separation = c1 + axis_code = i + 3 * wp.int32(pos21[i] < 0) + 0 # Face of box1 + if c2 < separation: + separation = c2 + axis_code = i + 3 * wp.int32(pos12[i] < 0) + 6 # Face of box2 + + clnorm = wp.vec3(0.0) + inv = wp.bool(False) + cle1 = wp.int32(0) + cle2 = wp.int32(0) + + # Second test: consider cross products of boxes' edges + for i in range(3): + for j in range(3): + # Compute cross product of box edges (potential separating axis) + if i == 0: + cross_axis = wp.vec3(0.0, -rot12[j, 2], rot12[j, 1]) + elif i == 1: + cross_axis = wp.vec3(rot12[j, 2], 0.0, -rot12[j, 0]) + else: + cross_axis = wp.vec3(-rot12[j, 1], rot12[j, 0], 0.0) + + cross_length = wp.length(cross_axis) + if cross_length < MJ_MINVAL: + continue + + cross_axis /= cross_length + + box_dist = wp.dot(pos21, cross_axis) + c3 = wp.float32(0.0) + + # Project box half-sizes onto the potential separating axis + for k in range(3): + if k != i: + c3 += box1.size[k] * wp.abs(cross_axis[k]) + if k != j: + c3 += box2.size[k] * rot21abs[i, 3 - k - j] / cross_length + + c3 -= wp.abs(box_dist) + + # Early exit: no collision if separated along this axis + if c3 < -margin: + return + + # Track minimum separation and which edge-edge pair it occurs on + if c3 < separation * (1.0 - 1e-12): + separation = c3 + # Determine which corners/edges are closest + cle1 = 0 + cle2 = 0 + + for k in range(3): + if k != i and (int(cross_axis[k] > 0) ^ int(box_dist < 0)): + cle1 += 1 << k + if k != j and (int(rot21[i, 3 - k - j] > 0) ^ int(box_dist < 0) ^ int((k - j + 3) % 3 == 1)): + cle2 += 1 << k + + axis_code = 12 + i * 3 + j + clnorm = cross_axis + inv = box_dist < 0 + + # No axis with separation < margin found + if axis_code == -1: + return + + points = mat83f() + depth = vec8f() + max_con_pair = 8 + # 8 contacts should suffice for most configurations + + if axis_code < 12: + # Handle face-vertex collision + face_idx = axis_code % 6 + box_idx = axis_code / 6 + rotmore = _compute_rotmore(face_idx) + + r = rotmore @ wp.where(box_idx, rot12, rot21) + p = rotmore @ wp.where(box_idx, pos12, pos21) + ss = wp.abs(rotmore @ wp.where(box_idx, box2.size, box1.size)) + s = wp.where(box_idx, box1.size, box2.size) + rt = wp.transpose(r) + + lx, ly, hz = ss[0], ss[1], ss[2] + p[2] -= hz + + clcorner = wp.int32(0) # corner of non-face box with least axis separation + + for i in range(3): + if r[2, i] < 0: + clcorner += 1 << i + + lp = p + for i in range(wp.static(3)): + lp += rt[i] * s[i] * wp.where(clcorner & 1 << i, 1.0, -1.0) + + m = wp.int32(1) + dirs = wp.int32(0) + + cn1 = wp.vec3(0.0) + cn2 = wp.vec3(0.0) + + for i in range(3): + if wp.abs(r[2, i]) < 0.5: + if not dirs: + cn1 = rt[i] * s[i] * wp.where(clcorner & (1 << i), -2.0, 2.0) + else: + cn2 = rt[i] * s[i] * wp.where(clcorner & (1 << i), -2.0, 2.0) + + dirs += 1 + + k = dirs * dirs + + # Find potential contact points + + n = wp.int32(0) + + for i in range(k): + for q in range(2): + # lines_a and lines_b (lines between corners) computed on the fly + lav = lp + wp.where(i < 2, wp.vec3(0.0), wp.where(i == 2, cn1, cn2)) + lbv = wp.where(i == 0 or i == 3, cn1, cn2) + + if wp.abs(lbv[q]) > MJ_MINVAL: + br = 1.0 / lbv[q] + for j in range(-1, 2, 2): + l = ss[q] * wp.float32(j) + c1 = (l - lav[q]) * br + if c1 < 0 or c1 > 1: + continue + c2 = lav[1 - q] + lbv[1 - q] * c1 + if wp.abs(c2) > ss[1 - q]: + continue + + points[n] = lav + c1 * lbv + n += 1 + + if dirs == 2: + ax = cn1[0] + bx = cn2[0] + ay = cn1[1] + by = cn2[1] + C = 1.0 / (ax * by - bx * ay) + + for i in range(4): + llx = wp.where(i / 2, lx, -lx) + lly = wp.where(i % 2, ly, -ly) + + x = llx - lp[0] + y = lly - lp[1] + + u = (x * by - y * bx) * C + v = (y * ax - x * ay) * C + + if u > 0 and v > 0 and u < 1 and v < 1: + points[n] = wp.vec3(llx, lly, lp[2] + u * cn1[2] + v * cn2[2]) + n += 1 + + for i in range(1 << dirs): + tmpv = lp + wp.float32(i & 1) * cn1 + wp.float32((i & 2) != 0) * cn2 + if tmpv[0] > -lx and tmpv[0] < lx and tmpv[1] > -ly and tmpv[1] < ly: + points[n] = tmpv + n += 1 + + m = n + n = wp.int32(0) + + for i in range(m): + if points[i][2] > margin: + continue + if i != n: + points[n] = points[i] + + points[n, 2] *= 0.5 + depth[n] = points[n, 2] + n += 1 + + # Set up contact frame + rw = wp.where(box_idx, box2.rot, box1.rot) @ wp.transpose(rotmore) + pw = wp.where(box_idx, box2.pos, box1.pos) + normal = wp.where(box_idx, -1.0, 1.0) * wp.transpose(rw)[2] + + else: + # Handle edge-edge collision + edge1 = (axis_code - 12) / 3 + edge2 = (axis_code - 12) % 3 + + # Set up non-contacting edges ax1, ax2 for box2 and pax1, pax2 for box 1 + ax1 = wp.int(1 - (edge2 & 1)) + ax2 = wp.int(2 - (edge2 & 2)) + + pax1 = wp.int(1 - (edge1 & 1)) + pax2 = wp.int(2 - (edge1 & 2)) + + if rot21abs[edge1, ax1] < rot21abs[edge1, ax2]: + ax1, ax2 = ax2, ax1 + + if rot12abs[edge2, pax1] < rot12abs[edge2, pax2]: + pax1, pax2 = pax2, pax1 + + rotmore = _compute_rotmore(wp.where(cle1 & (1 << pax2), pax2, pax2 + 3)) + + # Transform coordinates for edge-edge contact calculation + p = rotmore @ pos21 + rnorm = rotmore @ clnorm + r = rotmore @ rot21 + rt = wp.transpose(r) + s = wp.abs(wp.transpose(rotmore) @ box1.size) + + lx, ly, hz = s[0], s[1], s[2] + p[2] -= hz + + # Calculate closest box2 face + + points[0] = ( + p + + rt[ax1] * box2.size[ax1] * wp.where(cle2 & (1 << ax1), 1.0, -1.0) + + rt[ax2] * box2.size[ax2] * wp.where(cle2 & (1 << ax2), 1.0, -1.0) + ) + points[1] = points[0] - rt[edge2] * box2.size[edge2] + points[0] += rt[edge2] * box2.size[edge2] + + points[2] = ( + p + + rt[ax1] * box2.size[ax1] * wp.where(cle2 & (1 << ax1), -1.0, 1.0) + + rt[ax2] * box2.size[ax2] * wp.where(cle2 & (1 << ax2), 1.0, -1.0) + ) + + points[3] = points[2] - rt[edge2] * box2.size[edge2] + points[2] += rt[edge2] * box2.size[edge2] + + n = 4 + + # Set up coordinate axes for contact face of box2 + axi_lp = points[0] + axi_cn1 = points[1] - points[0] + axi_cn2 = points[2] - points[0] + + # Check if contact normal is valid + if wp.abs(rnorm[2]) < MJ_MINVAL: + return # Shouldn't happen + + # Calculate inverse normal for projection + innorm = wp.where(inv, -1.0, 1.0) / rnorm[2] + + pu = mat43f() + + # Project points onto contact plane + for i in range(4): + pu[i] = points[i] + c_scl = points[i, 2] * wp.where(inv, -1.0, 1.0) * innorm + points[i] -= rnorm * c_scl + + pts_lp = points[0] + pts_cn1 = points[1] - points[0] + pts_cn2 = points[2] - points[0] + + n = wp.int32(0) + + for i in range(4): + for q in range(2): + la = pts_lp[q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[q], pts_cn2[q])) + lb = wp.where(i == 0 or i == 3, pts_cn1[q], pts_cn2[q]) + lc = pts_lp[1 - q] + wp.where(i < 2, 0.0, wp.where(i == 2, pts_cn1[1 - q], pts_cn2[1 - q])) + ld = wp.where(i == 0 or i == 3, pts_cn1[1 - q], pts_cn2[1 - q]) + + # linesu_a and linesu_b (lines between corners) computed on the fly + lua = axi_lp + wp.where(i < 2, wp.vec3(0.0), wp.where(i == 2, axi_cn1, axi_cn2)) + lub = wp.where(i == 0 or i == 3, axi_cn1, axi_cn2) + + if wp.abs(lb) > MJ_MINVAL: + br = 1.0 / lb + for j in range(-1, 2, 2): + if n == max_con_pair: + break + l = s[q] * wp.float32(j) + c1 = (l - la) * br + if c1 < 0 or c1 > 1: + continue + c2 = lc + ld * c1 + if wp.abs(c2) > s[1 - q]: + continue + if (lua[2] + lub[2] * c1) * innorm > margin: + continue + + points[n] = lua * 0.5 + c1 * lub * 0.5 + points[n, q] += 0.5 * l + points[n, 1 - q] += 0.5 * c2 + depth[n] = points[n, 2] * innorm * 2.0 + n += 1 + + nl = n + + ax = pts_cn1[0] + bx = pts_cn2[0] + ay = pts_cn1[1] + by = pts_cn2[1] + C = 1.0 / (ax * by - bx * ay) + + for i in range(4): + if n == max_con_pair: + break + llx = wp.where(i / 2, lx, -lx) + lly = wp.where(i % 2, ly, -ly) + + x = llx - pts_lp[0] + y = lly - pts_lp[1] + + u = (x * by - y * bx) * C + v = (y * ax - x * ay) * C + + if nl == 0: + if (u < 0 or u > 0) and (v < 0 or v > 1): + continue + elif u < 0 or v < 0 or u > 1 or v > 1: + continue + + u = wp.clamp(u, 0.0, 1.0) + v = wp.clamp(v, 0.0, 1.0) + w = 1.0 - u - v + vtmp = pu[0] * w + pu[1] * u + pu[2] * v + + points[n] = wp.vec3(llx, lly, 0.0) + + vtmp2 = points[n] - vtmp + tc1 = wp.length_sq(vtmp2) + if vtmp[2] > 0 and tc1 > margin * margin: + continue + + points[n] = 0.5 * (points[n] + vtmp) + + depth[n] = wp.sqrt(tc1) * wp.where(vtmp[2] < 0, -1.0, 1.0) + n += 1 + + nf = n + + for i in range(4): + if n >= max_con_pair: + break + x = pu[i, 0] + y = pu[i, 1] + if nl == 0 and nf != 0: + if (x < -lx or x > lx) and (y < -ly or y > ly): + continue + elif x < -lx or x > lx or y < -ly or y > ly: + continue + + c1 = wp.float32(0) + + for j in range(2): + if pu[i, j] < -s[j]: + c1 += (pu[i, j] + s[j]) * (pu[i, j] + s[j]) + elif pu[i, j] > s[j]: + c1 += (pu[i, j] - s[j]) * (pu[i, j] - s[j]) + + c1 += pu[i, 2] * innorm * pu[i, 2] * innorm + + if pu[i, 2] > 0 and c1 > margin * margin: + continue + + tmp_p = wp.vec3(pu[i, 0], pu[i, 1], 0.0) + + for j in range(2): + if pu[i, j] < -s[j]: + tmp_p[j] = -s[j] * 0.5 + elif pu[i, j] > s[j]: + tmp_p[j] = +s[j] * 0.5 + + tmp_p += pu[i] + points[n] = tmp_p * 0.5 + + depth[n] = wp.sqrt(c1) * wp.where(pu[i, 2] < 0, -1.0, 1.0) + n += 1 + + # Set up contact data for all points + rw = box1.rot @ wp.transpose(rotmore) + pw = box1.pos + normal = wp.where(inv, -1.0, 1.0) * rw @ rnorm + + frame = make_frame(normal) + coff = wp.atomic_add(ncon_out, 0, n) + + for i in range(min(nconmax_in - coff, n)): + points[i, 2] += hz + pos = rw @ points[i] + pw + + cid = coff + i + + contact_dist_out[cid] = depth[i] + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + contact_geom_out[cid] = geoms + contact_worldid_out[cid] = worldid + contact_includemargin_out[cid] = margin - gap + contact_dim_out[cid] = condim + contact_friction_out[cid] = friction + contact_solref_out[cid] = solref + contact_solreffriction_out[cid] = solreffriction + contact_solimp_out[cid] = solimp + + @wp.kernel def _primitive_narrowphase( # Model: @@ -2061,6 +2564,33 @@ def _primitive_narrowphase( contact_geom_out, contact_worldid_out, ) + elif type1 == int(GeomType.BOX.value) and type2 == int(GeomType.BOX.value): + box_box( + nconmax_in, + geom1, + geom2, + worldid, + margin, + gap, + condim, + friction, + solref, + solreffriction, + solimp, + geoms, + ncon_out, + contact_dist_out, + contact_pos_out, + contact_frame_out, + contact_includemargin_out, + contact_friction_out, + contact_solref_out, + contact_solreffriction_out, + contact_solimp_out, + contact_dim_out, + contact_geom_out, + contact_worldid_out, + ) elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.BOX.value): capsule_box( nconmax_in,