Skip to content

Model batching #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9455124
move nworld to model.
adenzler-nvidia Apr 2, 2025
c990e0b
qpos0 with stride 0, seems to work well.
adenzler-nvidia Apr 3, 2025
2ba4bca
add ability to expand to real array
adenzler-nvidia Apr 3, 2025
815455a
converting qpos_spring
adenzler-nvidia Apr 3, 2025
1f57d9e
move to more descriptive naming
adenzler-nvidia Apr 3, 2025
78f1dcf
some more conversions
adenzler-nvidia Apr 3, 2025
b57aaf8
invweight0
adenzler-nvidia Apr 3, 2025
3ab0e6d
remove body_contype and body_conaffinity
adenzler-nvidia Apr 3, 2025
e18eee8
jnt_solref anf jnt_solimp
adenzler-nvidia Apr 3, 2025
df0236e
jnt_range
adenzler-nvidia Apr 3, 2025
e5d14b4
jnt_stiffness, jnt_margin, jnt_actfrcrange
adenzler-nvidia Apr 3, 2025
84cd8c9
geom_contype and conaffinity
adenzler-nvidia Apr 3, 2025
1bb6b62
geom_pos, geom_quat, geom_priority
adenzler-nvidia Apr 3, 2025
4615c9a
geom_solref, solimp, mix, friction, margin, gap
adenzler-nvidia Apr 3, 2025
8e04aa2
size pos, quat
adenzler-nvidia Apr 3, 2025
86a2f56
dof_armature
adenzler-nvidia Apr 3, 2025
59792a0
dof_damping
adenzler-nvidia Apr 3, 2025
7422a23
dof invweight and damping
adenzler-nvidia Apr 3, 2025
9dc58a6
final pieces
adenzler-nvidia Apr 3, 2025
31147e4
update documentation comments
adenzler-nvidia Apr 3, 2025
452b6cc
don't create temporary array
adenzler-nvidia Apr 3, 2025
bf32f95
better auto-expand
adenzler-nvidia Apr 3, 2025
fab9ac7
WIP test
adenzler-nvidia Apr 3, 2025
626eee2
Merge branch 'main' into dev/adenzler/model-batching-v2
adenzler-nvidia Apr 4, 2025
fabcd60
add tiling for new fields in put_data
adenzler-nvidia Apr 4, 2025
8c1433e
fixes after merging main
adenzler-nvidia Apr 4, 2025
0221529
add simple test
adenzler-nvidia Apr 4, 2025
4f42c8c
formatting
adenzler-nvidia Apr 4, 2025
da0cfbe
fix ruff errors
adenzler-nvidia Apr 4, 2025
51bdc48
adjust tolerance in smooth_test a bit to avoid flakyness in camlight …
adenzler-nvidia Apr 4, 2025
58e2067
remove accidental file commit
adenzler-nvidia Apr 4, 2025
a5b5cd3
Merge branch 'main' into dev/adenzler/model-batching-v2
adenzler-nvidia Apr 7, 2025
45dd091
some fixes after merging main
adenzler-nvidia Apr 7, 2025
3d240e5
more fixes
adenzler-nvidia Apr 7, 2025
c424f83
obvious mistake that is tricky to spot. Needs a better error.
adenzler-nvidia Apr 7, 2025
2f052e3
formatting
adenzler-nvidia Apr 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/jax_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv)
mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero
mujoco.mj_forward(mjm, mjd)
m = mjwarp.put_model(mjm)
m = mjwarp.put_model(mjm, nworld=NWORLDS)
d = mjwarp.put_data(mjm, mjd, nworld=NWORLDS, nconmax=131012, njmax=131012 * 4)


Expand Down
3 changes: 2 additions & 1 deletion mujoco_warp/_src/broad_phase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def test_nxn_broadphase(self):
np.testing.assert_allclose(d2.collision_pair.numpy()[2][1], 2)

# two worlds and four collisions
m3 = mjwarp.put_model(mjm, nworld=2)
d3 = mjwarp.make_data(mjm, nworld=2)
d3.geom_xpos = wp.array(
np.vstack(
Expand All @@ -216,7 +217,7 @@ def test_nxn_broadphase(self):
dtype=wp.vec3,
)

collision_driver.nxn_broadphase(m, d3)
collision_driver.nxn_broadphase(m3, d3)
np.testing.assert_allclose(d3.ncollision.numpy()[0], 4)
np.testing.assert_allclose(d3.collision_pair.numpy()[0][0], 0)
np.testing.assert_allclose(d3.collision_pair.numpy()[0][1], 1)
Expand Down
2 changes: 1 addition & 1 deletion mujoco_warp/_src/collision_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def box_box_kernel(
for i in range(4):
pos[i] = pos[idx]

margin = wp.max(m.geom_margin[ga], m.geom_margin[gb])
margin = wp.max(m.geom_margin[worldid, ga], m.geom_margin[worldid, gb])
for i in range(4):
pos_glob = b_mat @ pos[i] + b_pos
n_glob = b_mat @ sep_axis
Expand Down
70 changes: 40 additions & 30 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@


@wp.func
def _geom_filter(m: Model, geom1: int, geom2: int, filterparent: bool) -> bool:
def _geom_filter(
m: Model, geom1: int, geom2: int, filterparent: bool, worldid: int
) -> bool:
bodyid1 = m.geom_bodyid[geom1]
bodyid2 = m.geom_bodyid[geom2]
contype1 = m.geom_contype[geom1]
contype2 = m.geom_contype[geom2]
conaffinity1 = m.geom_conaffinity[geom1]
conaffinity2 = m.geom_conaffinity[geom2]
contype1 = m.geom_contype[worldid, geom1]
contype2 = m.geom_contype[worldid, geom2]
conaffinity1 = m.geom_conaffinity[worldid, geom1]
conaffinity2 = m.geom_conaffinity[worldid, geom2]
weldid1 = m.body_weldid[bodyid1]
weldid2 = m.body_weldid[bodyid2]
weld_parentid1 = m.body_weldid[m.body_parentid[weldid1]]
Expand Down Expand Up @@ -82,7 +84,7 @@ def broadphase_project_spheres_onto_sweep_direction_kernel(
if r == 0.0:
# current geom is a plane
r = 1000000000.0
sphere_radius = r + m.geom_margin[i]
sphere_radius = r + m.geom_margin[worldid, i]

center = wp.dot(direction, c)
f = center - sphere_radius
Expand Down Expand Up @@ -146,7 +148,7 @@ def reorder_bounding_spheres_kernel(
# Get the bounding volume
c = d.geom_xpos[worldid, mapped]
r = m.geom_rbound[mapped]
margin = m.geom_margin[mapped]
margin = m.geom_margin[worldid, mapped]

# Reorder the box into the sorted array
if r == 0.0:
Expand Down Expand Up @@ -291,7 +293,7 @@ def sap_broadphase_kernel(m: Model, d: Data, num_threads: int, filter_parent: bo
idx1 = d.sap_sort_index[worldid, i]
idx2 = d.sap_sort_index[worldid, j]

if not _geom_filter(m, idx1, idx2, filter_parent):
if not _geom_filter(m, idx1, idx2, filter_parent, worldid):
threadId += num_threads
continue

Expand All @@ -317,17 +319,19 @@ def get_contact_solver_params_kernel(
g1 = geoms.x
g2 = geoms.y

margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
gap = wp.max(m.geom_gap[g1], m.geom_gap[g2])
solmix1 = m.geom_solmix[g1]
solmix2 = m.geom_solmix[g2]
worldid = d.contact.worldid[tid]

margin = wp.max(m.geom_margin[worldid, g1], m.geom_margin[worldid, g2])
gap = wp.max(m.geom_gap[worldid, g1], m.geom_gap[worldid, g2])
solmix1 = m.geom_solmix[worldid, g1]
solmix2 = m.geom_solmix[worldid, g2]
mix = solmix1 / (solmix1 + solmix2)
mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix)
mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix)
mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix)

p1 = m.geom_priority[g1]
p2 = m.geom_priority[g2]
p1 = m.geom_priority[worldid, g1]
p2 = m.geom_priority[worldid, g2]
mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0))

condim1 = m.geom_condim[g1]
Expand All @@ -337,15 +341,21 @@ def get_contact_solver_params_kernel(
)
d.contact.dim[tid] = condim

if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0:
d.contact.solref[tid] = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2]
if m.geom_solref[worldid, g1].x > 0.0 and m.geom_solref[worldid, g2].x > 0.0:
d.contact.solref[tid] = (
mix * m.geom_solref[worldid, g1] + (1.0 - mix) * m.geom_solref[worldid, g2]
)
else:
d.contact.solref[tid] = wp.min(m.geom_solref[g1], m.geom_solref[g2])
d.contact.solref[tid] = wp.min(
m.geom_solref[worldid, g1], m.geom_solref[worldid, g2]
)
d.contact.includemargin[tid] = margin - gap
friction_ = wp.max(m.geom_friction[g1], m.geom_friction[g2])
friction_ = wp.max(m.geom_friction[worldid, g1], m.geom_friction[worldid, g2])
friction5 = vec5(friction_[0], friction_[0], friction_[1], friction_[2], friction_[2])
d.contact.friction[tid] = friction5
d.contact.solimp[tid] = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2]
d.contact.solimp[tid] = (
mix * m.geom_solimp[worldid, g1] + (1.0 - mix) * m.geom_solimp[worldid, g2]
)


def sap_broadphase(m: Model, d: Data):
Expand All @@ -357,7 +367,7 @@ def sap_broadphase(m: Model, d: Data):

wp.launch(
kernel=broadphase_project_spheres_onto_sweep_direction_kernel,
dim=(d.nworld, m.ngeom),
dim=(m.nworld, m.ngeom),
inputs=[m, d, direction],
)

Expand All @@ -367,19 +377,19 @@ def sap_broadphase(m: Model, d: Data):
if tile_sort_available:
segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom)
wp.launch_tiled(
kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128
kernel=segmented_sort_kernel, dim=(m.nworld), inputs=[m, d], block_dim=128
)
print("tile sort available")
elif segmented_sort_available:
wp.utils.segmented_sort_pairs(
d.sap_projection_lower,
d.sap_sort_index,
m.ngeom * d.nworld,
m.ngeom * m.nworld,
d.sap_segment_index,
)
else:
# Sort each world's segment separately
for world_id in range(d.nworld):
for world_id in range(m.nworld):
start_idx = world_id * m.ngeom

# Create temporary arrays for sorting
Expand Down Expand Up @@ -431,21 +441,21 @@ def sap_broadphase(m: Model, d: Data):

wp.launch(
kernel=reorder_bounding_spheres_kernel,
dim=(d.nworld, m.ngeom),
dim=(m.nworld, m.ngeom),
inputs=[m, d],
)

wp.launch(
kernel=sap_broadphase_prepare_kernel,
dim=(d.nworld, m.ngeom),
dim=(m.nworld, m.ngeom),
inputs=[m, d],
)

# The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads
wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True)

# Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds)
num_sweep_threads = 5 * d.nworld * m.ngeom
num_sweep_threads = 5 * m.nworld * m.ngeom
filter_parent = not m.opt.disableflags & DisableBit.FILTERPARENT.value
wp.launch(
kernel=sap_broadphase_kernel,
Expand Down Expand Up @@ -478,8 +488,8 @@ def _nxn_broadphase(m: Model, d: Data):
+ (m.ngeom - geom1) * ((m.ngeom - geom1) - 1) // 2
)

margin1 = m.geom_margin[geom1]
margin2 = m.geom_margin[geom2]
margin1 = m.geom_margin[worldid, geom1]
margin2 = m.geom_margin[worldid, geom2]
pos1 = d.geom_xpos[worldid, geom1]
pos2 = d.geom_xpos[worldid, geom2]
size1 = m.geom_rbound[geom1]
Expand All @@ -503,13 +513,13 @@ def _nxn_broadphase(m: Model, d: Data):
dist = wp.dot(-dif, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2]))
bounds_filter = dist <= bound

geom_filter = _geom_filter(m, geom1, geom2, filterparent)
geom_filter = _geom_filter(m, geom1, geom2, filterparent, worldid)

if bounds_filter and geom_filter:
_add_geom_pair(m, d, geom1, geom2, worldid)

wp.launch(
_nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d]
_nxn_broadphase, dim=(m.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d]
)


Expand Down
2 changes: 1 addition & 1 deletion mujoco_warp/_src/collision_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _primitive_narrowphase(
geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])

margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
margin = wp.max(m.geom_margin[worldid, g1], m.geom_margin[worldid, g2])

# TODO(team): static loop unrolling to remove unnecessary branching
if type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.SPHERE.value):
Expand Down
22 changes: 13 additions & 9 deletions mujoco_warp/_src/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def _efc_limit_slide_hinge(
jntid = m.jnt_limited_slide_hinge_adr[jntlimitedid]

qpos = d.qpos[worldid, m.jnt_qposadr[jntid]]
jnt_range = m.jnt_range[jntid]
jnt_range = m.jnt_range[worldid, jntid]
dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos
pos = wp.min(dist_min, dist_max) - m.jnt_margin[jntid]
pos = wp.min(dist_min, dist_max) - m.jnt_margin[worldid, jntid]
active = pos < 0

if active:
Expand All @@ -130,10 +130,10 @@ def _efc_limit_slide_hinge(
efcid,
pos,
pos,
m.dof_invweight0[dofadr],
m.jnt_solref[jntid],
m.jnt_solimp[jntid],
m.jnt_margin[jntid],
m.dof_invweight0[worldid, dofadr],
m.jnt_solref[worldid, jntid],
m.jnt_solimp[worldid, jntid],
m.jnt_margin[worldid, jntid],
refsafe,
Jqvel,
)
Expand Down Expand Up @@ -179,7 +179,9 @@ def _efc_contact_pyramidal(
frame = d.contact.frame[conid]

# pyramidal has common invweight across all edges
invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0]
invweight = (
m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0]
)

if condim > 1:
dimid2 = dimid / 2 + 1
Expand Down Expand Up @@ -285,7 +287,9 @@ def _efc_contact_elliptic(
d.efc.J[efcid, i] = J
Jqvel += J * d.qvel[worldid, i]

invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0]
invweight = (
m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0]
)

ref = d.contact.solref[conid]
pos_aref = pos
Expand Down Expand Up @@ -340,7 +344,7 @@ def make_constraint(m: types.Model, d: types.Data):
):
wp.launch(
_efc_limit_slide_hinge,
dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size),
dim=(m.nworld, m.jnt_limited_slide_hinge_adr.size),
inputs=[m, d, refsafe],
)

Expand Down
Loading