Skip to content

Device Management in Multi-GPU systems, v2 #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
beb5034
io.py with scopedDevice
adenzler-nvidia Apr 24, 2025
ab1ec4c
collision
adenzler-nvidia Apr 24, 2025
dafb0b3
nxn broadphase
adenzler-nvidia Apr 24, 2025
e7e9ab6
sap broadphase
adenzler-nvidia Apr 24, 2025
2fa3cea
primitive narrowphase
adenzler-nvidia Apr 24, 2025
8591d76
make_constraint
adenzler-nvidia Apr 24, 2025
a70ece1
euler
adenzler-nvidia Apr 24, 2025
8119575
forward and fwd_acceleration
adenzler-nvidia Apr 24, 2025
b616e75
fwd_actuation
adenzler-nvidia Apr 24, 2025
0968897
fwd_position
adenzler-nvidia Apr 24, 2025
2b86709
fwd_velocity
adenzler-nvidia Apr 24, 2025
c0ac2ec
implicit
adenzler-nvidia Apr 24, 2025
35f905a
rungeKutta4
adenzler-nvidia Apr 24, 2025
667451b
step
adenzler-nvidia Apr 24, 2025
9d226f8
passive
adenzler-nvidia Apr 24, 2025
1aade98
sensor_acc
adenzler-nvidia Apr 24, 2025
ac45bbb
sensor_pos
adenzler-nvidia Apr 24, 2025
1816393
sensor_vel
adenzler-nvidia Apr 24, 2025
2060b39
com_pos
adenzler-nvidia Apr 24, 2025
b0decf3
com_vel
adenzler-nvidia Apr 24, 2025
62f0807
crb
adenzler-nvidia Apr 24, 2025
15fff11
factor_m
adenzler-nvidia Apr 24, 2025
7f08964
kinematics
adenzler-nvidia Apr 24, 2025
fcb8455
rne
adenzler-nvidia Apr 24, 2025
21339e9
rne_postconstraint
adenzler-nvidia Apr 24, 2025
daa52a3
solve_m
adenzler-nvidia Apr 24, 2025
ebecd76
subtree_vel
adenzler-nvidia Apr 24, 2025
56e8840
tendon
adenzler-nvidia Apr 24, 2025
ea4f0bc
transmission
adenzler-nvidia Apr 24, 2025
69e84c8
solve
adenzler-nvidia Apr 24, 2025
af663bf
remove warp function from API
adenzler-nvidia Apr 24, 2025
98c4760
mul_m
adenzler-nvidia Apr 24, 2025
d750cf8
xfrc_accumulate
adenzler-nvidia Apr 24, 2025
417edd3
benchmark
adenzler-nvidia Apr 24, 2025
4c1feee
fix missing else in euler
adenzler-nvidia Apr 24, 2025
f8635a6
add test
adenzler-nvidia Apr 24, 2025
e2f9f77
formatting
adenzler-nvidia Apr 24, 2025
012a2a4
ruff fixes
adenzler-nvidia Apr 24, 2025
aca1e8b
Merge branch 'main' into dev/adenzler/device-management-v2
adenzler-nvidia Apr 24, 2025
82f7350
Merge branch 'main' into dev/adenzler/device-management-v2
adenzler-nvidia Apr 24, 2025
6606ff4
ruff format
adenzler-nvidia Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mujoco_warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from ._src.smooth import tendon as tendon
from ._src.smooth import transmission as transmission
from ._src.solver import solve as solve
from ._src.support import contact_force as contact_force
from ._src.support import is_sparse as is_sparse
from ._src.support import mul_m as mul_m
from ._src.support import xfrc_accumulate as xfrc_accumulate
Expand Down
138 changes: 70 additions & 68 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,97 +167,99 @@ def _sap_broadphase(m: Model, d: Data, nsweep: int, filterparent: bool):
def sap_broadphase(m: Model, d: Data):
"""Broadphase collision detection via sweep-and-prune."""

nworldgeom = d.nworld * m.ngeom
with wp.ScopedDevice(m.qpos0.device):
nworldgeom = d.nworld * m.ngeom

# TODO(team): direction
# TODO(team): direction

# random fixed direction
direction = wp.vec3(0.5935, 0.7790, 0.1235)
direction = wp.normalize(direction)
# random fixed direction
direction = wp.vec3(0.5935, 0.7790, 0.1235)
direction = wp.normalize(direction)

wp.launch(
kernel=_sap_project,
dim=(d.nworld, m.ngeom),
inputs=[m, d, direction],
)
wp.launch(
kernel=_sap_project,
dim=(d.nworld, m.ngeom),
inputs=[m, d, direction],
)

# TODO(team): tile sort
# TODO(team): tile sort

wp.utils.segmented_sort_pairs(
d.sap_projection_lower,
d.sap_sort_index,
nworldgeom,
d.sap_segment_index,
)
wp.utils.segmented_sort_pairs(
d.sap_projection_lower,
d.sap_sort_index,
nworldgeom,
d.sap_segment_index,
)

wp.launch(
kernel=_sap_range,
dim=(d.nworld, m.ngeom),
inputs=[m, d],
)
wp.launch(
kernel=_sap_range,
dim=(d.nworld, m.ngeom),
inputs=[m, d],
)

# scan is used for load balancing among the threads
wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True)
# scan is used for load balancing among the threads
wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True)

# estimate number of overlap checks - assumes each geom has 5 other geoms (batched over all worlds)
nsweep = 5 * nworldgeom
filterparent = not m.opt.disableflags & DisableBit.FILTERPARENT.value
wp.launch(
kernel=_sap_broadphase,
dim=nsweep,
inputs=[m, d, nsweep, filterparent],
)
# estimate number of overlap checks - assumes each geom has 5 other geoms (batched over all worlds)
nsweep = 5 * nworldgeom
filterparent = not m.opt.disableflags & DisableBit.FILTERPARENT.value
wp.launch(
kernel=_sap_broadphase,
dim=nsweep,
inputs=[m, d, nsweep, filterparent],
)


def nxn_broadphase(m: Model, d: Data):
"""Broadphase collision detection via brute-force search."""

@wp.kernel
def _nxn_broadphase(m: Model, d: Data):
worldid, elementid = wp.tid()
with wp.ScopedDevice(m.qpos0.device):

# check for valid geom pair
if m.nxn_pairid[elementid] < -1:
return
@wp.kernel
def _nxn_broadphase(m: Model, d: Data):
worldid, elementid = wp.tid()

geom = m.nxn_geom_pair[elementid]
geom1 = geom[0]
geom2 = geom[1]
# check for valid geom pair
if m.nxn_pairid[elementid] < -1:
return

if _sphere_filter(m, d, geom1, geom2, worldid):
_add_geom_pair(m, d, geom1, geom2, worldid, elementid)
geom = m.nxn_geom_pair[elementid]
geom1 = geom[0]
geom2 = geom[1]

if m.nxn_geom_pair.shape[0]:
wp.launch(_nxn_broadphase, dim=(d.nworld, m.nxn_geom_pair.shape[0]), inputs=[m, d])
if _sphere_filter(m, d, geom1, geom2, worldid):
_add_geom_pair(m, d, geom1, geom2, worldid, elementid)

if m.nxn_geom_pair.shape[0]:
wp.launch(
_nxn_broadphase, dim=(d.nworld, m.nxn_geom_pair.shape[0]), inputs=[m, d]
)


@event_scope
def collision(m: Model, d: Data):
"""Collision detection."""

# AD: based on engine_collision_driver.py in Eric's warp fork/mjx-collisions-dev
# which is further based on the CUDA code here:
# https://github.com/btaba/mujoco/blob/warp-collisions/mjx/mujoco/mjx/_src/cuda/engine_collision_driver.cu.cc#L458-L583

d.ncollision.zero_()
d.ncon.zero_()
with wp.ScopedDevice(m.qpos0.device):
d.ncollision.zero_()
d.ncon.zero_()

if d.nconmax == 0:
return
if d.nconmax == 0:
return

dsbl_flgs = m.opt.disableflags
if (dsbl_flgs & DisableBit.CONSTRAINT) | (dsbl_flgs & DisableBit.CONTACT):
return
dsbl_flgs = m.opt.disableflags
if (dsbl_flgs & DisableBit.CONSTRAINT) | (dsbl_flgs & DisableBit.CONTACT):
return

# TODO(team): determine ngeom to switch from n^2 to sap
if m.ngeom <= 100:
nxn_broadphase(m, d)
else:
sap_broadphase(m, d)

# TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
# partitioning because we can move some pressure off the atomics
# TODO(team) switch between collision functions and GJK/EPA here
gjk_narrowphase(m, d)
primitive_narrowphase(m, d)
box_box_narrowphase(m, d)
# TODO(team): determine ngeom to switch from n^2 to sap
if m.ngeom <= 100:
nxn_broadphase(m, d)
else:
sap_broadphase(m, d)

# TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
# partitioning because we can move some pressure off the atomics
# TODO(team) switch between collision functions and GJK/EPA here
gjk_narrowphase(m, d)
primitive_narrowphase(m, d)
box_box_narrowphase(m, d)
7 changes: 4 additions & 3 deletions mujoco_warp/_src/collision_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ def _primitive_narrowphase(


def primitive_narrowphase(m: Model, d: Data):
# we need to figure out how to keep the overhead of this small - not launching anything
# for pair types without collisions, as well as updating the launch dimensions.
wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d])
with wp.ScopedDevice(m.qpos0.device):
# we need to figure out how to keep the overhead of this small - not launching anything
# for pair types without collisions, as well as updating the launch dimensions.
wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d])
125 changes: 63 additions & 62 deletions mujoco_warp/_src/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,79 +704,80 @@ def _update_nefc(d: types.Data):
def make_constraint(m: types.Model, d: types.Data):
"""Creates constraint jacobians and other supporting data."""

d.ne.zero_()
d.ne_connect.zero_()
d.ne_weld.zero_()
d.ne_jnt.zero_()
d.nefc.zero_()
d.nf.zero_()
d.nl.zero_()

if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value):
d.efc.J.zero_()

if not (m.opt.disableflags & types.DisableBit.EQUALITY.value):
wp.launch(
_efc_equality_connect,
dim=(d.nworld, m.eq_connect_adr.size),
inputs=[m, d],
)
wp.launch(
_efc_equality_weld,
dim=(d.nworld, m.eq_wld_adr.size),
inputs=[m, d],
)
wp.launch(
_efc_equality_joint,
dim=(d.nworld, m.eq_jnt_adr.size),
inputs=[m, d],
)

wp.launch(_num_equality, dim=(1,), inputs=[d])

if not (m.opt.disableflags & types.DisableBit.FRICTIONLOSS.value):
wp.launch(
_efc_friction,
dim=(d.nworld, m.nv),
inputs=[m, d],
)

# limit
if not (m.opt.disableflags & types.DisableBit.LIMIT.value):
limit_ball = m.jnt_limited_ball_adr.size > 0
if limit_ball:
with wp.ScopedDevice(m.qpos0.device):
d.ne.zero_()
d.ne_connect.zero_()
d.ne_weld.zero_()
d.ne_jnt.zero_()
d.nefc.zero_()
d.nf.zero_()
d.nl.zero_()

if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value):
d.efc.J.zero_()

if not (m.opt.disableflags & types.DisableBit.EQUALITY.value):
wp.launch(
_efc_limit_ball,
dim=(d.nworld, m.jnt_limited_ball_adr.size),
_efc_equality_connect,
dim=(d.nworld, m.eq_connect_adr.size),
inputs=[m, d],
)

limit_slide_hinge = m.jnt_limited_slide_hinge_adr.size > 0
if limit_slide_hinge:
wp.launch(
_efc_limit_slide_hinge,
dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size),
_efc_equality_weld,
dim=(d.nworld, m.eq_wld_adr.size),
inputs=[m, d],
)

limit_tendon = m.tendon_limited_adr.size > 0
if limit_tendon:
wp.launch(
_efc_limit_tendon,
dim=(d.nworld, m.tendon_limited_adr.size),
_efc_equality_joint,
dim=(d.nworld, m.eq_jnt_adr.size),
inputs=[m, d],
)

if limit_ball or limit_slide_hinge or limit_tendon:
wp.launch(_update_nefc, dim=(1,), inputs=[d])
wp.launch(_num_equality, dim=(1,), inputs=[d])

# contact
if not (m.opt.disableflags & types.DisableBit.CONTACT.value):
if m.opt.cone == types.ConeType.PYRAMIDAL.value:
if not (m.opt.disableflags & types.DisableBit.FRICTIONLOSS.value):
wp.launch(
_efc_contact_pyramidal,
dim=(d.nconmax, 2 * (m.condim_max - 1) if m.condim_max > 1 else 1),
_efc_friction,
dim=(d.nworld, m.nv),
inputs=[m, d],
)
elif m.opt.cone == types.ConeType.ELLIPTIC.value:
wp.launch(_efc_contact_elliptic, dim=(d.nconmax, m.condim_max), inputs=[m, d])

# limit
if not (m.opt.disableflags & types.DisableBit.LIMIT.value):
limit_ball = m.jnt_limited_ball_adr.size > 0
if limit_ball:
wp.launch(
_efc_limit_ball,
dim=(d.nworld, m.jnt_limited_ball_adr.size),
inputs=[m, d],
)

limit_slide_hinge = m.jnt_limited_slide_hinge_adr.size > 0
if limit_slide_hinge:
wp.launch(
_efc_limit_slide_hinge,
dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size),
inputs=[m, d],
)

limit_tendon = m.tendon_limited_adr.size > 0
if limit_tendon:
wp.launch(
_efc_limit_tendon,
dim=(d.nworld, m.tendon_limited_adr.size),
inputs=[m, d],
)

if limit_ball or limit_slide_hinge or limit_tendon:
wp.launch(_update_nefc, dim=(1,), inputs=[d])

# contact
if not (m.opt.disableflags & types.DisableBit.CONTACT.value):
if m.opt.cone == types.ConeType.PYRAMIDAL.value:
wp.launch(
_efc_contact_pyramidal,
dim=(d.nconmax, 2 * (m.condim_max - 1) if m.condim_max > 1 else 1),
inputs=[m, d],
)
elif m.opt.cone == types.ConeType.ELLIPTIC.value:
wp.launch(_efc_contact_elliptic, dim=(d.nconmax, m.condim_max), inputs=[m, d])
Loading